Skip to content

[API] Update head_first parameter default to False#324

Merged
yzhangcs merged 4 commits intomainfrom
varlen
Apr 7, 2025
Merged

[API] Update head_first parameter default to False#324
yzhangcs merged 4 commits intomainfrom
varlen

Conversation

@yzhangcs
Copy link
Copy Markdown
Member

@yzhangcs yzhangcs commented Apr 7, 2025

We will drop head_first in releases > 0.2

Summary by CodeRabbit

  • Refactor

    • Standardized the default processing order for tensor dimensions across core operations, now favoring a non-head-first (i.e., tail-first) approach for improved consistency.
  • Documentation

    • Updated all related documentation and error messages to clearly reflect the new default tensor format and shape expectations.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 7, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

This pull request updates the default value of the head_first parameter from True to False across a wide range of functions in various modules. The changes adjust the interpretation of tensor shapes in both forward and backward operations and update associated documentation and error messages accordingly. In one case, a parameter was renamed for consistency. No new functionality is introduced; the modifications solely alter parameter defaults and documentation to reflect a revised processing order.

Changes

Files Change Summary
fla/ops/abc/chunk.py; fla/ops/based/fused_chunk.py, fla/ops/based/parallel.py; fla/ops/common/chunk_delta_h.py, fla/ops/common/chunk_h.py, fla/ops/common/chunk_h_parallel.py, fla/ops/common/chunk_h_split.py, fla/ops/common/chunk_o.py, fla/ops/common/fused_recurrent.py Updated default head_first from True to False in various chunk and recurrent functions; tensor shape documentation revised accordingly.
fla/ops/delta_rule/chunk.py, fla/ops/delta_rule/fused_recurrent.py, fla/ops/delta_rule/parallel.py; fla/ops/gated_delta_rule/chunk.py, fla/ops/gated_delta_rule/wy_fast.py Changed default head_first to False in delta rule and gated delta rule functions, with updated shape expectations and docs.
fla/ops/generalized_delta_rule/dplr/*, fla/ops/generalized_delta_rule/iplr/* Across DPLR/IPLR modules, modified default head_first to False; updated error messages, tensor shape notes, and renamed a parameter (offsetscu_seqlens).
fla/ops/gla/chunk.py, fla/ops/gla/fused_chunk.py, fla/ops/gla/fused_recurrent.py; fla/ops/gsa/chunk.py, fla/ops/gsa/fused_recurrent.py Adjusted default head_first to False in GLA and GSA operations, with corresponding updates to tensor shape descriptions.
fla/ops/lightning_attn/chunk.py, fla/ops/lightning_attn/fused_recurrent.py; fla/ops/linear_attn/chunk.py, fla/ops/linear_attn/fused_chunk.py, fla/ops/linear_attn/fused_recurrent.py Updated default head_first to False in attention functions; revised tensor dimensions and transposition logic documented.
fla/ops/nsa/naive.py, fla/ops/rebased/parallel.py; fla/ops/retention/chunk.py, fla/ops/retention/fused_chunk.py, fla/ops/retention/fused_recurrent.py, fla/ops/retention/parallel.py Changed head_first to False, reordering parameters or updating error messages and tensor shape descriptions in NSA, Rebased, and Retention modules.
fla/ops/rwkv6/*, fla/ops/rwkv7/* In RWKV6/7 operations, default head_first is now False with updated docstrings and error messaging to reflect the new tensor shape arrangements.
fla/ops/simple_gla/*, fla/ops/titans/naive.py, fla/ops/ttt/* Across Simple GLA, TTT, and Titans operations, default head_first changed to False with updated documentation to align with new tensor dimension orders.
fla/ops/utils/cumsum.py Modified default head_first to False in multiple cumsum functions; reformatted error messages for improved clarity.

Sequence Diagram(s)

Possibly related PRs

  • [TTT] Update fused chunk ops and state bias term #230: Updates addressing the head_first parameter’s default value across functions, similar to the modifications in this pull request, with an emphasis on tensor shape handling.
  • [WY representation] Faster lower triangle inverse #289: The changes in the main PR regarding the head_first parameter and its impact on tensor shapes are related to the modifications in the retrieved PR, which also involves handling the head_first parameter and its implications for tensor operations in the context of new functions.

Poem

I'm a little bunny in a code-filled glade,
Hopping through functions where defaults are remade.
head_first is now false, a twist in the tale,
Tensor shapes align on a smoother trail.
Documentation sings with clarity so bright,
As code hops forward in soft daylight.
Happy bytes and bouncy leaps—coding feels just right! 🐰✨


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a56294 and 1461f5b.

📒 Files selected for processing (38)
  • fla/ops/attn/parallel.py (6 hunks)
  • fla/ops/common/chunk_delta_h.py (9 hunks)
  • fla/ops/common/chunk_h.py (6 hunks)
  • fla/ops/common/chunk_h_parallel.py (12 hunks)
  • fla/ops/common/chunk_h_split.py (12 hunks)
  • fla/ops/common/chunk_o.py (12 hunks)
  • fla/ops/common/chunk_scaled_dot_kkt.py (3 hunks)
  • fla/ops/common/fused_recurrent.py (8 hunks)
  • fla/ops/delta_rule/fused_recurrent.py (8 hunks)
  • fla/ops/delta_rule/wy_fast.py (4 hunks)
  • fla/ops/forgetting_attn/parallel.py (7 hunks)
  • fla/ops/gated_delta_rule/fused_recurrent.py (3 hunks)
  • fla/ops/gated_delta_rule/wy_fast.py (9 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py (5 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py (5 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py (3 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py (3 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py (9 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py (3 hunks)
  • fla/ops/generalized_delta_rule/dplr/fused_recurrent.py (4 hunks)
  • fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py (2 hunks)
  • fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (7 hunks)
  • fla/ops/generalized_delta_rule/iplr/chunk.py (11 hunks)
  • fla/ops/generalized_delta_rule/iplr/fused_recurrent.py (8 hunks)
  • fla/ops/generalized_delta_rule/iplr/wy_fast.py (7 hunks)
  • fla/ops/gla/chunk.py (29 hunks)
  • fla/ops/gsa/chunk.py (20 hunks)
  • fla/ops/hgrn/fused_recurrent.py (4 hunks)
  • fla/ops/nsa/compression.py (6 hunks)
  • fla/ops/nsa/parallel.py (8 hunks)
  • fla/ops/rwkv6/chunk.py (26 hunks)
  • fla/ops/rwkv6/fused_recurrent.py (14 hunks)
  • fla/ops/simple_gla/parallel.py (8 hunks)
  • fla/ops/ttt/chunk.py (24 hunks)
  • fla/ops/ttt/fused_chunk.py (9 hunks)
  • fla/ops/utils/cumsum.py (16 hunks)
  • fla/ops/utils/pooling.py (4 hunks)
  • fla/ops/utils/solve_tril.py (6 hunks)
✨ Finishing Touches
  • 📝 Generate Docstrings

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai plan to trigger planning for file edits and PR creation.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 13

🔭 Outside diff range comments (1)
fla/ops/delta_rule/parallel.py (1)

326-328: ⚠️ Potential issue

Docstring needs update to match the new default

The docstring still indicates Default: 'True' on line 327, which contradicts the actual default which is now False. This needs to be updated to maintain consistency.

Apply this change:

        head_first (Optional[bool]):
            Whether the inputs are in the head-first format.
-           Default: `True`.
+           Default: `False`.
🧹 Nitpick comments (2)
fla/ops/retention/fused_recurrent.py (1)

22-23: Code refactored for better readability

The code has been refactored to use a single variable H for the number of heads based on the head_first parameter, making the code more concise and readable.

Consider adding a brief comment explaining what H represents (number of heads) for clarity.

fla/ops/nsa/naive.py (1)

49-53: Consider combining nested if statements

Static analysis suggests combining the nested if statements to improve readability.

-    if cu_seqlens is not None:
-        if head_first:
-            raise RuntimeError(
-                "Sequences with variable lengths are not supported for head-first mode"
-            )
+    if cu_seqlens is not None and head_first:
+        raise RuntimeError(
+            "Sequences with variable lengths are not supported for head-first mode"
+        )
🧰 Tools
🪛 Ruff (0.8.2)

49-50: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1a8367b and 4205b0f.

📒 Files selected for processing (54)
  • fla/ops/abc/chunk.py (2 hunks)
  • fla/ops/based/fused_chunk.py (1 hunks)
  • fla/ops/based/parallel.py (1 hunks)
  • fla/ops/common/chunk_delta_h.py (2 hunks)
  • fla/ops/common/chunk_h.py (2 hunks)
  • fla/ops/common/chunk_h_parallel.py (2 hunks)
  • fla/ops/common/chunk_h_split.py (2 hunks)
  • fla/ops/common/chunk_o.py (4 hunks)
  • fla/ops/common/fused_recurrent.py (4 hunks)
  • fla/ops/delta_rule/chunk.py (3 hunks)
  • fla/ops/delta_rule/fused_recurrent.py (4 hunks)
  • fla/ops/delta_rule/parallel.py (2 hunks)
  • fla/ops/gated_delta_rule/chunk.py (3 hunks)
  • fla/ops/gated_delta_rule/wy_fast.py (1 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk.py (5 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py (1 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py (1 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py (1 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py (1 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py (3 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py (1 hunks)
  • fla/ops/generalized_delta_rule/dplr/fused_recurrent.py (2 hunks)
  • fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (1 hunks)
  • fla/ops/generalized_delta_rule/iplr/chunk.py (7 hunks)
  • fla/ops/generalized_delta_rule/iplr/fused_recurrent.py (3 hunks)
  • fla/ops/generalized_delta_rule/iplr/wy_fast.py (1 hunks)
  • fla/ops/gla/chunk.py (11 hunks)
  • fla/ops/gla/fused_chunk.py (1 hunks)
  • fla/ops/gla/fused_recurrent.py (3 hunks)
  • fla/ops/gsa/chunk.py (10 hunks)
  • fla/ops/gsa/fused_recurrent.py (7 hunks)
  • fla/ops/lightning_attn/chunk.py (2 hunks)
  • fla/ops/lightning_attn/fused_recurrent.py (2 hunks)
  • fla/ops/linear_attn/chunk.py (2 hunks)
  • fla/ops/linear_attn/fused_chunk.py (2 hunks)
  • fla/ops/linear_attn/fused_recurrent.py (1 hunks)
  • fla/ops/nsa/naive.py (2 hunks)
  • fla/ops/rebased/parallel.py (1 hunks)
  • fla/ops/retention/chunk.py (2 hunks)
  • fla/ops/retention/fused_chunk.py (2 hunks)
  • fla/ops/retention/fused_recurrent.py (1 hunks)
  • fla/ops/retention/parallel.py (2 hunks)
  • fla/ops/rwkv6/chunk.py (10 hunks)
  • fla/ops/rwkv6/fused_recurrent.py (6 hunks)
  • fla/ops/rwkv7/chunk.py (2 hunks)
  • fla/ops/rwkv7/fused_recurrent.py (2 hunks)
  • fla/ops/simple_gla/chunk.py (5 hunks)
  • fla/ops/simple_gla/fused_recurrent.py (3 hunks)
  • fla/ops/simple_gla/parallel.py (3 hunks)
  • fla/ops/titans/naive.py (1 hunks)
  • fla/ops/ttt/chunk.py (12 hunks)
  • fla/ops/ttt/fused_chunk.py (7 hunks)
  • fla/ops/ttt/naive.py (1 hunks)
  • fla/ops/utils/cumsum.py (7 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
fla/ops/nsa/naive.py

49-50: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (184)
fla/ops/rebased/parallel.py (1)

450-450: Default parameter value change is consistent with PR objectives.

The change of the default value for head_first from True to False aligns with the PR objectives. This affects the interpretation of tensor shapes. When head_first=False (new default), the code will transpose tensors at lines 458 and 465, assuming input in [B, T, H, K] format rather than [B, H, T, K].

fla/ops/titans/naive.py (1)

316-316: Parameter default updated for consistent API across modules.

The change of the default value for head_first from True to False aligns with the PR objectives. When head_first=False (new default), the code will transpose tensors at lines 321-327 and 373-374, changing the expected input format to [B, T, H, ...] instead of [B, H, T, ...].

fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py (1)

134-134: API parameter default change matches PR objective.

The default value for head_first has been changed from True to False, which is consistent with the PR objectives. This impacts how tensor shapes are interpreted at lines 137-140 and 165-168, changing the expected default format from [B, H, T, K, V] to [B, T, H, K, V].

fla/ops/generalized_delta_rule/iplr/wy_fast.py (1)

256-256: Default parameter value updated for API consistency.

The default value for head_first has been changed from True to False, which aligns with the PR objectives. This affects tensor shape interpretation at lines 259-262 and 268, changing the expected default format from [B, H, T, K] to [B, T, H, K].

fla/ops/gla/fused_chunk.py (1)

611-620:

Details

✅ Verification successful

Verified change of head_first default from True to False

This change reverses the default behavior of tensor dimension handling. Now q, k, v, g tensors will be transposed (lines 623-624) by default, and the output o will also be transposed (lines 629-630) by default.

Make sure all calling code is updated to expect this new behavior or explicitly sets head_first=True where the old behavior is required.


🏁 Script executed:

#!/bin/bash
# Find and check usages of fused_chunk_gla to ensure they're compatible with the new default

# Check for direct calls to fused_chunk_gla
rg "fused_chunk_gla\s*\(" -A 10

Length of output: 3987


Confirmed default behavior update for head_first.
The grep results show that all explicit calls in fla/layers/hgrn2.py and fla/layers/gla.py are already passing head_first=False, and the benchmarks correctly rely on the new default. Please ensure that any new invocations needing the previous behavior continue to explicitly set head_first=True.

fla/ops/linear_attn/fused_recurrent.py (1)

232-241: Verified change of head_first default from True to False

This change aligns with the other changes in this PR. Now q, k, v tensors will be transposed (lines 244-245) by default, and the output o will also be transposed (lines 249-250) by default.

This provides consistent behavior across all related functions, improving API consistency.

fla/ops/common/chunk_h_split.py (2)

490-502: Verified change of head_first default from True to False

This change modifies the default tensor handling behavior. With this change, tensor shapes are handled differently by default (see lines 506-509), which ensures consistent behavior with other similar functions that had the same change.


593-597: Verified change of head_first default from True to False

This parameter change aligns with the previous function in this file and with other similar functions across the codebase. The function correctly handles both cases through conditional logic in lines 575-580 and 594-597 where different tensor shapes are created based on the value of this parameter.

fla/ops/common/chunk_h_parallel.py (2)

480-492: Verified change of head_first default from True to False

This change is consistent with other similar changes in this PR. The function properly accounts for both cases with the conditionals at lines 494-497 and 509 where tensor shapes are handled differently based on this parameter.


558-574: Verified change of head_first default from True to False

Similar to the previous function change, this maintains consistency with the API changes throughout the PR. The function properly handles the parameter with conditionals at lines 575-580 and 594-597.

fla/ops/retention/chunk.py (1)

25-25: Documentation updated to match new default value

The function documentation has been properly updated to reflect the new default value of head_first=False. All shape descriptions now list the head_first=False format first, which aligns with the updated parameter default.

Also applies to: 27-27, 29-29, 44-44, 48-48

fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py (1)

372-372: Default parameter change looks good

The default value for the head_first parameter has been changed from True to False, which is consistent with the PR objective. This changes the default tensor shape interpretation from (B, H, T, K) to (B, T, H, K).

fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (1)

241-241: Default parameter change is consistent

The default value for the head_first parameter has been changed from True to False, which aligns with the PR objective. This changes the default tensor shape interpretation from (B, H, T, K) to (B, T, H, K).

fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py (1)

109-109: Default parameter change is consistent

The default value for the head_first parameter has been changed from True to False, which aligns with the PR objective. This changes the default tensor shape interpretation from (B, H, T, K, V) to (B, T, H, K, V).

fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py (3)

312-312: Default parameter change is consistent

The default value for the head_first parameter has been changed from True to False in the chunk_dplr_bwd_dv function, which aligns with the PR objective.


362-362: Default parameter change is consistent

The default value for the head_first parameter has been changed from True to False in the chunk_dplr_bwd_o function, which aligns with the PR objective.


423-423: Default parameter change is consistent

The default value for the head_first parameter has been changed from True to False in the chunk_dplr_bwd_dAu function, which aligns with the PR objective.

fla/ops/common/chunk_delta_h.py (2)

428-429: Parameter defaults updated to reflect new behavior

The default value of head_first has been changed from True to False, which alters how tensor dimensions are interpreted. This change affects how offsets, strides, and tensor expansions are calculated throughout the function.


493-494: Parameter defaults updated to reflect new behavior

The default value of head_first has been changed from True to False, consistent with the forward function. The comment about removing the chunk_size argument also indicates a possible future refactoring to simplify the API.

fla/ops/retention/parallel.py (3)

18-18: Default behavior changed from head-first to tail-first

The default value of head_first parameter has been changed from True to False, impacting how tensor dimensions are interpreted by default.


23-27: Documentation updated to reflect new tensor shape expectations

The documentation has been updated to clearly describe the expected tensor shapes based on the head_first parameter value. The new default behavior (head_first=False) puts the head dimension after the sequence dimension.


41-41: Output tensor shape documentation updated

The documentation for the output tensor shape has been updated to match the new default behavior, ensuring consistency with the input parameters.

fla/ops/retention/fused_recurrent.py (1)

20-20: Default behavior changed from head-first to tail-first

The default value of head_first parameter has been changed from True to False, consistent with the changes in other files.

fla/ops/generalized_delta_rule/iplr/fused_recurrent.py (4)

409-409: Parameter renamed for API consistency

The parameter offsets has been renamed to cu_seqlens to maintain consistency with the FlashAttention API and better reflect its purpose (cumulative sequence lengths).


433-436: Documentation updated for renamed parameter

Documentation has been updated to reflect the renamed parameter and clearly explain its purpose as cumulative sequence lengths used for variable-length training.


438-452: Validation logic updated for renamed parameter

The validation checks have been updated to reference cu_seqlens instead of offsets, with improved error messages that clearly explain the constraints.


458-458: Function call updated to use renamed parameter

Updated the function call to pass cu_seqlens instead of offsets to maintain consistency with the parameter renaming.

fla/ops/retention/fused_chunk.py (3)

329-329: Approval: Change to default parameter value looks good

The default value of head_first has been changed from True to False, which aligns with the PR objective of updating this parameter's default across the codebase.


334-338: Documentation correctly updated to reflect new tensor shape expectations

The docstring has been properly updated to show the expected tensor shapes based on the head_first parameter value, with the head_first=False case now listed first as it's the new default.


347-348: Documentation updated consistently for parameter description and return shape

The parameter description and return shape documentation have been updated to match the new default value of head_first=False.

Also applies to: 351-352

fla/ops/nsa/naive.py (2)

23-29: Documentation updated to reflect tensor shape expectations

The docstring has been properly updated to show the expected tensor shapes based on the head_first parameter value, with the head_first=False case now listed first as it's the new default.

Also applies to: 40-45


17-18:

Details

✅ Verification successful

Parameter order changed and default value updated

The head_first parameter has been moved after cu_seqlens and its default value has been changed from True to False, which aligns with the PR objective.

Note: This parameter reordering may affect code that relies on positional arguments rather than keyword arguments.

Run the following script to check for potential call sites that might be affected by the parameter reordering:


🏁 Script executed:

#!/bin/bash
# Search for call sites of naive_nsa function that might be affected by parameter reordering
rg -A 5 -B 5 "naive_nsa\(" --type py | grep -v "def naive_nsa"

Length of output: 1663


Parameter Ordering Verified: Changes are Safe
The updated parameter ordering (with head_first moved after cu_seqlens and its default set to False) aligns with the PR objective. Verification of call sites in the test suite shows that all invocations of naive_nsa use keyword arguments exclusively, so no issues with positional parameters were detected. Nonetheless, if any external code still relies on positional arguments, it may be affected.

fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py (1)

285-285: Approval: Default parameter value changed as expected

The default value of head_first has been changed from True to False, which aligns with the PR objective of updating this parameter's default across the codebase.

fla/ops/gated_delta_rule/chunk.py (3)

29-29: Approval: Default parameter value changed as expected

The default value of head_first in the chunk_gated_delta_rule_fwd function has been changed from True to False, which aligns with the PR objective.


88-88: Approval: Default parameter value changed as expected

The default value of head_first in the chunk_gated_delta_rule_bwd function has been changed from True to False, which aligns with the PR objective.


196-196: Approval: Default parameter value changed as expected

The default value of head_first in the forward method of ChunkGatedDeltaRuleFunction has been changed from True to False, which aligns with the PR objective.

fla/ops/simple_gla/fused_recurrent.py (5)

70-74: Example updated to use the new default

The example code has been properly updated to no longer specify the head_first parameter explicitly, which is good for consistency with the new default value.


79-84: Example updated to use the new default

The variable-length example code has been properly updated to no longer specify the head_first parameter explicitly, which is good for consistency with the new default value.


26-32: Documentation updated to reflect the new parameter default

The documentation for input tensors has been properly updated to describe shapes with head_first=False first, followed by the head_first=True case, which aligns well with the new default.


54-54: Documentation updated to reflect the new parameter default

The output shape documentation has been properly updated to describe the shape with head_first=False first, followed by the head_first=True case, which aligns well with the new default.


90-102: Error messages reformatted for better readability

The error messages have been reformatted with proper line breaks and f-strings, improving readability while maintaining the same error conditions.

fla/ops/gla/fused_recurrent.py (3)

27-35: Documentation updated to reflect the new parameter default

The documentation for input tensors has been properly updated to describe shapes with head_first=False first, followed by the head_first=True case, which aligns well with the new default.


56-56: Documentation updated to reflect the new parameter default

The output shape documentation has been properly updated to describe the shape with head_first=False first, followed by the head_first=True case, which aligns well with the new default.


90-102: Error messages reformatted for better readability

The error messages have been reformatted with proper line breaks and f-strings, improving readability while maintaining the same error conditions.

fla/ops/delta_rule/parallel.py (3)

313-319: Documentation updated to reflect the new parameter default

The documentation for input tensors has been properly updated to describe shapes with head_first=False first, followed by the head_first=True case, which aligns well with the new default.


331-331: Documentation updated to reflect the new parameter default

The output shape documentation has been properly updated to describe the shape with head_first=False first, followed by the head_first=True case, which aligns well with the new default.


308-308:

Details

❓ Verification inconclusive

Default value changed from True to False

The default value of head_first parameter has been changed from True to False. This is a breaking change for code that relies on the default behavior without explicitly specifying this parameter.


🏁 Script executed:

#!/bin/bash
# Identify potential usage of parallel_delta_rule without explicitly specifying head_first
rg "parallel_delta_rule\s*\(" --type py | grep -v "head_first\s*="

Length of output: 124


Attention: Confirm Impact of Default Parameter Change

The default value of head_first has indeed been changed from True to False in fla/ops/delta_rule/parallel.py (line 308). Our verification search for calls to parallel_delta_rule that omit an explicit head_first parameter returned only the function definition, and no additional call sites were found that rely on the default value.

  • Observation: The automated search did not reveal any invocations using the default value implicitly.
  • Caution: Although no caller appears affected from our grep result, please verify that there are no dynamic or indirect usages (such as callbacks or higher-order functions) that might rely on the previous default.

Please review surrounding usage contexts to ensure that this breaking change was intentional and does not introduce unexpected behavior.

fla/ops/delta_rule/fused_recurrent.py (6)

530-532: Documentation updated to reflect the new parameter default

The documentation for the head_first parameter has been properly updated to reflect its new default value of False.


510-514: Documentation updated to reflect the new parameter default

The input tensor shape documentation has been properly formatted to list the head_first=False case first, followed by the head_first=True case, which aligns well with the new default.


535-535: Documentation updated to reflect the new parameter default

The output tensor shape documentation has been properly formatted to list the head_first=False case first, followed by the head_first=True case, which aligns well with the new default.


551-555: Example updated to use the new default

The example code has been properly updated to no longer specify the head_first parameter explicitly, which is good for consistency with the new default value.


560-565: Example updated to use the new default

The variable-length example code has been properly updated to no longer specify the head_first parameter explicitly, which is good for consistency with the new default value.


441-441:

Details

❓ Verification inconclusive

Default value changed from True to False

The default value of head_first parameter for the forward method in FusedRecurrentFunction has been changed from True to False. This is a breaking change for code that relies on the default behavior without explicitly specifying this parameter.


🏁 Script executed:

#!/bin/bash
# Identify potential usage of FusedRecurrentFunction.apply without explicitly specifying head_first as the 9th parameter
rg "FusedRecurrentFunction\.apply\s*\(" --type py

Length of output: 310


🏁 Script executed:

#!/bin/bash
# Check if any explicit "head_first" argument is passed in usages of FusedRecurrentFunction.apply
rg "FusedRecurrentFunction\.apply\s*\(.*head_first=" --type py

Length of output: 64


Attention: Breaking Change Detected in FusedRecurrentFunction

The default value for the head_first parameter in the forward method has been changed from True to False, introducing a breaking change. Our investigation shows that all calls to FusedRecurrentFunction.apply in the following files do not specify the head_first parameter explicitly:

  • fla/ops/common/fused_recurrent.py
  • fla/ops/delta_rule/fused_recurrent.py
  • fla/ops/gated_delta_rule/fused_recurrent.py

Since these call sites rely on the default behavior, this change might lead to unintended behavior in downstream code. Please verify whether this change is intentional. If the previous behavior is required, consider either reverting the default to True or updating these call sites to explicitly pass head_first=True.

fla/ops/generalized_delta_rule/dplr/chunk.py (4)

34-34: Default parameter value changed for consistency.

The default value of head_first has been changed from True to False, which is consistent with the PR objective to update this parameter default across various modules.


119-119: Parameter default updated consistently.

The default value change from True to False for head_first matches the changes made in the parent function.


328-339: Documentation updated correctly to reflect new default behavior.

The tensor shape descriptions have been properly updated to reflect the new head_first=False default condition.


367-379: Error messages reformatted for better readability.

The error messages have been reformatted with improved multiline structure while maintaining the same logic and error conditions.

fla/ops/simple_gla/chunk.py (5)

25-25: Default parameter value changed for consistency.

The default value of head_first has been changed from True to False, aligning with the PR objective.


68-68: Parameter default updated in backward implementation.

The head_first parameter default value has been correctly updated in the backward implementation to maintain consistency with the forward implementation.


218-218: Default parameter value changed in public function.

The head_first parameter default value has been updated in the main exported function, ensuring the change is propagated to users of this API.


223-229: Documentation updated to reflect new default behavior.

The tensor shape descriptions in the docstrings have been properly updated to match the new head_first=False default value.

Also applies to: 245-245, 249-249


284-296: Error messages reformatted for improved readability.

The error messages have been reformatted with better multiline structure for clarity while maintaining the same validation logic.

fla/ops/delta_rule/chunk.py (3)

28-28: Default parameter value changed for consistency.

The default value of head_first has been changed from True to False, aligning with the PR objective.


83-83: Parameter default updated in backward implementation.

The head_first parameter default value has been correctly updated in the backward implementation to maintain consistency with the forward implementation.


184-184: Parameter default updated in autograd function.

The head_first parameter default value change has been properly propagated to the autograd function implementation.

fla/ops/generalized_delta_rule/dplr/fused_recurrent.py (2)

131-131: Default parameter value changed for consistency.

The default value of head_first has been changed from True to False in the forward implementation, aligning with the PR objective.


267-279: Error messages reformatted for improved readability.

The error messages have been reformatted with better multiline structure for clarity while maintaining the same validation logic.

fla/ops/common/fused_recurrent.py (4)

335-335: Parameter default value updated appropriately.

The default value of head_first parameter has been changed from True to False, which will now make the function expect input tensors in the shape of [B, T, H, K] by default instead of [B, H, T, K].


396-396: Parameter default value updated appropriately.

Default for head_first changed from True to False in the backward function, consistent with the change in the forward function.


490-490: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the forward method of FusedRecurrentFunction class, maintaining consistency with the standalone functions.


558-558: Parameter default value updated appropriately.

The default value of head_first parameter in the public-facing fused_recurrent function has been changed from True to False, keeping it consistent with the internal implementation.

fla/ops/gsa/fused_recurrent.py (8)

96-96: Parameter default value updated appropriately.

The default value of head_first parameter has been changed from True to False in the fused_recurrent_gsa_inference function.


153-153: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the fused_recurrent_gsa_fwd function to maintain consistency with other changes.


253-253: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the fused_recurrent_gsa_bwd function, maintaining API consistency.


387-387: Parameter default value updated appropriately.

The default value of head_first parameter has been updated in the forward method of FusedRecurrentGSAFunction.


469-469: Parameter default value updated appropriately.

The default value of head_first parameter has been changed from True to False in the public-facing fused_recurrent_gsa function.


474-474: Documentation updated to reflect the new default parameter value.

The docstring has been updated to correctly show that when head_first=False (now the default), the tensor shapes should be [B, T, H, K].

Also applies to: 476-476, 478-478


500-500: Documentation default value updated.

The docstring has been updated to show that the default value for head_first is now False.


542-544: Error messages reformatted and updated.

The error messages have been reformatted and updated to maintain consistency with the parameter default change.

Also applies to: 547-549, 551-554

fla/ops/common/chunk_o.py (4)

465-465: Parameter default value updated appropriately.

The default value of head_first parameter has been changed from True to False in the chunk_fwd_o function.


509-509: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_bwd_dv function to maintain consistency.


562-562: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_bwd_dv_local function to maintain API consistency.


618-618: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_bwd_dqkwg function to maintain consistency with other functions.

fla/ops/gla/chunk.py (13)

853-853: Parameter default value updated appropriately.

The default value of head_first parameter has been changed from True to False in the chunk_gla_fwd_intra_gk function.


956-956: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_gla_fwd_o_gk function.


994-994: Parameter default value updated appropriately.

Default for head_first changed from True to False in the chunk_gla_bwd_dA function.


1032-1032: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_gla_bwd_dv function.


1070-1070: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_gla_bwd_dqk_intra function.


1121-1121: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_gla_bwd_dqkg function.


1173-1173: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_gla_fwd function.


1236-1236: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the chunk_gla_bwd function.


1412-1412: Parameter default value updated appropriately.

The default value of head_first parameter has been changed in the public-facing chunk_gla function.


1417-1417: Documentation updated to reflect the new default parameter value.

The docstring has been updated to correctly show that when head_first=False (now the default), the tensor shapes should be [B, T, H, K].

Also applies to: 1419-1419, 1421-1421, 1423-1423


1438-1438: Documentation default value updated.

The docstring has been updated to show that the default value for head_first is now False.


1442-1442: Return type documentation updated.

The return value documentation has been updated to reflect that with head_first=False (now the default), the output shape will be [B, T, H, V].


1476-1479: Error messages reformatted and updated.

The error messages have been reformatted and updated to maintain consistency with the parameter default change.

Also applies to: 1481-1483, 1485-1488

fla/ops/generalized_delta_rule/iplr/chunk.py (9)

222-222: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, aligning with the PR objective.


272-272: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, maintaining consistency across the codebase.


352-352: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, maintaining consistency across the codebase.


413-413: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False in the forward method, maintaining consistency across the codebase.


467-467: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False in the public chunk_iplr_delta_rule function, maintaining consistency across the codebase.


472-481: Documentation updates align well with the parameter change.

The documentation for input tensor shapes has been updated to reflect the new default value of head_first=False. This ensures users understand the expected format of input tensors.


495-495: Documentation for head_first parameter updated correctly.

The default value in the documentation now matches the code change from True to False.


499-499: Output shape documentation updated correctly.

The documentation for the output tensor shape has been updated to reflect the shape when head_first=False as the default case.


508-520: Error message formatting improved.

The error messages have been reformatted for better readability while maintaining the same logical conditions for when they are raised.

fla/ops/gsa/chunk.py (12)

619-619: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, aligning with the PR objective.


649-649: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, maintaining consistency across the codebase.


738-738: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, maintaining consistency across the codebase.


775-775: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, maintaining consistency across the codebase.


927-927: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, maintaining consistency across the codebase.


983-983: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, maintaining consistency across the codebase.


1059-1059: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False in the forward method, maintaining consistency across the codebase.


1153-1153: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False in the public chunk_gsa function, maintaining consistency across the codebase.


1158-1164: Documentation updates align well with the parameter change.

The documentation for input tensor shapes has been updated to reflect the new default value of head_first=False. This ensures users understand the expected format of input tensors.


1190-1190: Documentation for head_first parameter updated correctly.

The default value in the documentation now matches the code change from True to False.


1194-1194: Output shape documentation updated correctly.

The documentation for the output tensor shape has been updated to reflect the shape when head_first=False as the default case.


1231-1243: Error message formatting improved.

The error messages have been reformatted for better readability while maintaining the same logical conditions for when they are raised.

fla/ops/simple_gla/parallel.py (6)

487-487: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, aligning with the PR objective.


555-555: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, maintaining consistency across the codebase.


682-682: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False in the public parallel_simple_gla function, maintaining consistency across the codebase.


687-693: Documentation updates align well with the parameter change.

The documentation for input tensor shapes has been updated to reflect the new default value of head_first=False. This ensures users understand the expected format of input tensors.


701-701: Documentation for head_first parameter updated correctly.

The default value in the documentation now matches the code change from True to False.


708-708: Output shape documentation updated correctly.

The documentation for the output tensor shape has been updated to reflect the shape when head_first=False as the default case.

fla/ops/ttt/naive.py (1)

86-86: Default parameter value update looks good.

The default value for head_first parameter has been changed from True to False, aligning with the PR objective. The function logic already accounts for both cases, so this change only affects the default behavior.

fla/ops/rwkv7/chunk.py (2)

27-37: Docstring update correctly reflects the new parameter order.

The changed parameter descriptions now correctly list the shape for head_first=False first, which aligns with the new default value.


51-51: Updated default value documentation matches the parameter's default value.

The docstring now correctly indicates that the default value for head_first is False, which matches the parameter declaration on line 22.

fla/ops/rwkv7/fused_recurrent.py (2)

27-37: Docstring update correctly reflects the new parameter order.

The changed parameter descriptions now correctly list the shape for head_first=False first, which aligns with the new default value.


49-49: Updated default value documentation matches the parameter's default value.

The docstring now correctly indicates that the default value for head_first is False, which matches the parameter declaration on line 22.

fla/ops/based/fused_chunk.py (1)

362-362:

Details

✅ Verification successful

Default value update for head_first parameter.

The default value of head_first has been changed from True to False, making tensor shapes of format [B, T, H, K] the default instead of [B, H, T, K]. This change is part of the PR's objective to update the default value across all functions.


🏁 Script executed:

#!/bin/bash
# Find other occurrences of the fused_chunk_based function to check for potential impacts
rg "fused_chunk_based\(" --type py

Length of output: 443


Verified Default Value Update for head_first Parameter

The change from True to False for the head_first parameter in fla/ops/based/fused_chunk.py has been confirmed. We verified that all usages of fused_chunk_based—including those in tests/ops/test_based.py and benchmarks/ops/benchmark_based.py—rely on the updated default tensor shape of [B, T, H, K], consistent with the PR's objective. No additional modifications are needed.

fla/ops/based/parallel.py (1)

398-398:

Details

✅ Verification successful

Default value update for head_first parameter.

The default value of head_first has been changed from True to False, making tensor shapes of format [B, T, H, K] the default instead of [B, H, T, K]. This change aligns with the PR's objective to update this parameter across all functions.


🏁 Script executed:

#!/bin/bash
# Find other occurrences of the parallel_based function to check for potential impacts
rg "parallel_based\(" --type py

Length of output: 1156


Action: Default Value Update for head_first Parameter Confirmed

The change to set the default value of head_first to False in fla/ops/based/parallel.py (line 398) has been verified. A codebase search confirms that:

  • Functions like parallel_based in tests and benchmarks rely on the default, aligning with the updated tensor shape [B, T, H, K].
  • Explicit overrides (e.g., in fla/layers/based.py) are used where necessary.

No further modifications are needed.

fla/ops/linear_attn/chunk.py (4)

21-21: API default value change: head_first now defaults to False.

This default change aligns with the PR objective to update the default tensor format from head-first ([B, H, T, K]) to head-last ([B, T, H, K]). The function behavior remains correct for both tensor formats.


26-31: Documentation updated correctly to reflect the new default tensor format.

The documentation now properly reflects that when head_first=False (the new default), tensors have shape [B, T, H, K] rather than [B, H, T, K].


40-42: Default value documentation updated appropriately.

The default value in the docstring has been updated to match the parameter declaration.


45-45: Return type documentation updated to reflect the new default format.

The documentation for the output tensor shape now reflects the new default format ([B, T, H, V] when head_first=False).

fla/ops/lightning_attn/fused_recurrent.py (4)

22-22: API default value change: head_first now defaults to False.

This parameter change aligns with the PR objective of updating defaults across the codebase.


27-31: Documentation updated to reflect the new default tensor shapes.

The docstring for parameter shapes has been updated correctly for the new default value.


49-50: Parameter docstring updated to match the new default.

The documentation now correctly states that the default value is False.


54-54: Return shape documentation updated consistently.

The documentation for the output tensor shape now reflects the new default format.

fla/ops/abc/chunk.py (4)

1085-1085: API default value change: head_first now defaults to False.

This parameter change aligns with the PR objective of updating defaults across the codebase.


1090-1095: Documentation updated to reflect the new default tensor shapes.

The docstring for parameter shapes has been updated correctly for the new default value.


1102-1103: Parameter docstring updated to match the new default.

The documentation now correctly states that the default value is False.


1107-1107: Return shape documentation updated consistently.

The documentation for the output tensor shape now reflects the new default format.

fla/ops/linear_attn/fused_chunk.py (4)

281-281: API default value change: head_first now defaults to False.

This parameter change aligns with the PR objective of updating defaults across the codebase.


286-291: Documentation updated to reflect the new default tensor shapes.

The docstring for parameter shapes has been updated correctly for the new default value.


301-301: Parameter docstring updated to match the new default.

The documentation now correctly states that the default value is False.


305-305: Return shape documentation updated consistently.

The documentation for the output tensor shape now reflects the new default format.

fla/ops/lightning_attn/chunk.py (2)

26-31: Documentation updates to reflect new tensor shape expectations.

Documentation for input parameters q, k, v and the output value o has been appropriately updated to reflect the tensor shape changes based on the new default value of head_first.

Also applies to: 49-50, 54-54


22-22:

Details

❓ Verification inconclusive

Default parameter value change affecting tensor shape expectations.

The default value of head_first has been changed from True to False. This is a breaking change that affects the expected shapes of input and output tensors when users don't explicitly specify this parameter.

Check if there are any usages across the codebase that might rely on the previous default value:


🏁 Script executed:

#!/bin/bash
# Search for calls to chunk_lightning_attn without explicit head_first
rg "chunk_lightning_attn\([^)]*\)" -A 5 | grep -v "head_first\s*="

Length of output: 68


Action Required: Reassess Default head_first Behavior

The default value for head_first in fla/ops/lightning_attn/chunk.py (line 22) has been changed from True to False. This change may impact tensor shape expectations in scenarios where users do not explicitly pass in this parameter.

  • Impact: Without an override, users might see unexpected tensor shapes due to the altered default.
  • Automated Search: Our automated grep command did not return any calls to chunk_lightning_attn missing an explicit head_first parameter. However, given the lack of output, this result is inconclusive.
  • Recommendation: A manual verification is advised to ensure that there are no edge cases or less-obvious call patterns relying on the previous default.
fla/ops/rwkv6/fused_recurrent.py (2)

621-629: Documentation updates to reflect new tensor shape expectations.

Documentation for input parameters r/q, k, v, w and the output value o has been appropriately updated to reflect the tensor shape changes based on the new default value of head_first.

Also applies to: 648-648, 652-652


686-699: Error message formatting improvement.

The error messages have been reformatted to improve readability while maintaining the same content. Multi-line format with f-strings makes the error messages more readable in the code.

fla/ops/ttt/fused_chunk.py (2)

852-852: Documentation update to reflect new default value.

The documentation for the head_first parameter has been updated to correctly specify the new default value of False.


867-880: Error message formatting improvement.

The error messages have been reformatted to improve readability while maintaining the same content. Multi-line format with f-strings makes the error messages more readable in the code.

fla/ops/utils/cumsum.py (1)

232-232:

Details

❓ Verification inconclusive

Default parameter value change in multiple function signatures.

The default value of head_first has been changed from True to False in several functions: chunk_local_cumsum_scalar, chunk_local_cumsum_vector, chunk_global_cumsum_scalar, chunk_global_cumsum_vector, chunk_global_cumsum, and chunk_local_cumsum. This is a breaking change that affects the expected shapes of input and output tensors.

Check for existing usages that might rely on the previous default behavior:

Also applies to: 266-266, 303-303, 333-333, 366-366, 388-388


🏁 Script executed:

#!/bin/bash
# Search for calls to these functions without explicit head_first
rg "chunk_(local|global)_cumsum(_scalar|_vector)?\([^)]*\)" -A 5 | grep -v "head_first\s*="

Length of output: 10558


Action: Verify External Reliance on Default head_first Behavior

The default value for head_first has been changed from True to False in several cumulative sum functions (specifically, in chunk_local_cumsum_scalar, chunk_local_cumsum_vector, chunk_global_cumsum_scalar, chunk_global_cumsum_vector, chunk_global_cumsum, and chunk_local_cumsum). Our grep search shows that all internal call sites (e.g., in tests under tests/ops/) explicitly pass a head_first variable rather than relying on the default. This verifies that the internal usage does not inadvertently depend on the old default behavior, so internal tests are unaffected.

However, note that this remains a breaking change for external callers who previously relied on the default True behavior. Please ensure that all external usage is reviewed to confirm that users have been appropriately notified and that any necessary migration steps are documented.

  • Impacted Files/Lines:

    • fla/ops/utils/cumsum.py: Lines 232, 266, 303, 333, 366, 388
    • Additional test files and modules invoking cumulative sum functions
  • Diff Snippet (unchanged for review context):

        head_first: bool = False,
fla/ops/ttt/chunk.py (14)

729-729: Consistent API change for head_first parameter default.

The default value for head_first parameter has been updated from True to False, bringing consistency with other modules in the codebase.


801-801: Consistent parameter default change.

The default value for head_first parameter has been changed to False to maintain consistency across the API.


856-856: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


926-926: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


982-982: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1078-1078: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1178-1178: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1241-1241: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1292-1292: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1466-1466: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1498-1499: Documentation updated to reflect the new default value.

The documentation comment has been updated to indicate that the default value for head_first is now False, which is consistent with the parameter change.


1512-1515: Error message updated to reflect variable-length input constraints.

The error message has been improved to better explain the requirement for batch size when using cu_seqlens.


1517-1519: Error message updated for head-first mode limitations.

The error message now clearly indicates that variable-length sequences aren't supported with head-first mode.


1521-1524: Error message updated for initial state requirements.

The error message has been improved for clarity when the number of initial states doesn't match the number of input sequences.

fla/ops/rwkv6/chunk.py (17)

81-81: Consistent API change for head_first parameter default.

The default value for head_first parameter has been updated from True to False, bringing consistency with other modules in the codebase.


855-855: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


966-966: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1024-1024: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1080-1080: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1137-1137: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1197-1197: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1377-1377: Consistent parameter default change.

The head_first parameter default has been updated to False to align with the broader API change.


1382-1382: Documentation updated to reflect the new tensor shape expectations.

The documentation now correctly indicates the expected shape of the queries with the new default head_first=False parameter value.


1384-1384: Documentation updated for key tensor shape.

The documentation has been updated to reflect the expected shape of the keys with the new head_first=False default.


1386-1386: Documentation updated for value tensor shape.

The documentation has been updated to reflect the expected shape of the values with the new head_first=False default.


1388-1388: Documentation updated for forget gates tensor shape.

The documentation has been updated to reflect the expected shape of the forget gates with the new head_first=False default.


1405-1405: Documentation updated with new default value.

The documentation now correctly indicates that False is the default value for the head_first parameter.


1409-1409: Documentation updated for output tensor shape.

The documentation has been updated to reflect the expected shape of the outputs with the new head_first=False default.


1444-1447: Error message updated to reflect variable-length input constraints.

The error message has been improved to better explain the requirement for batch size when using cu_seqlens.


1449-1451: Error message updated for head-first mode limitations.

The error message now clearly indicates that variable-length sequences aren't supported with head-first mode.


1453-1456: Error message updated for initial state requirements.

The error message has been improved for clarity when the number of initial states doesn't match the number of input sequences.

Comment thread fla/ops/common/chunk_h.py
output_final_state: bool,
offsets: Optional[torch.Tensor] = None,
head_first: bool = True,
head_first: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Default parameter value changed from True to False

The default value of head_first parameter has been changed from True to False. This affects how tensor dimensions are interpreted in the chunk_fwd_h function. When head_first is False, the function expects input tensors to have shape [B, T, H, K] rather than [B, H, T, K].

This is part of a larger change across the codebase standardizing on head_first=False as the default format.

This change may impact existing code that relies on the default value. Ensure that all callers either explicitly set this parameter or are updated to work with the new tensor dimension ordering.

Comment thread fla/ops/common/chunk_h.py
scale: float,
offsets: Optional[torch.Tensor] = None,
head_first: bool = True,
head_first: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Default parameter value changed from True to False

The default value of head_first parameter has been changed from True to False. This affects how tensor dimensions are interpreted in the chunk_bwd_dh function. When head_first is False, the function expects input tensors to have shape [B, T, H, K] rather than [B, H, T, K].

This is part of a larger change across the codebase standardizing on head_first=False as the default format.

This change may impact existing code that relies on the default value. Ensure that all callers either explicitly set this parameter or are updated to work with the new tensor dimension ordering.

offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Default parameter value changed from True to False

The default value of head_first parameter has been changed from True to False. This affects how tensor dimensions are interpreted in the chunk_dplr_fwd_h function. When head_first is False, the function expects input tensors to have shape [B, T, H, K] rather than [B, H, T, K].

This is part of a larger change across the codebase standardizing on head_first=False as the default format.

This change may impact existing code that relies on the default value. Ensure that all callers either explicitly set this parameter or are updated to work with the new tensor dimension ordering.

offsets: Optional[torch.LongTensor],
indices: Optional[torch.LongTensor],
head_first: bool = True,
head_first: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Default parameter value changed from True to False

The default value of head_first parameter has been changed from True to False. This affects how tensor dimensions are interpreted in the fwd_prepare_wy_repr function. When head_first is False, the function expects input tensors to have shape [B, T, H, K] rather than [B, H, T, K].

This is part of a larger change across the codebase standardizing on head_first=False as the default format.

This change may impact existing code that relies on the default value. Ensure that all callers either explicitly set this parameter or are updated to work with the new tensor dimension ordering.

output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Default parameter value changed from True to False

The default value of head_first parameter has been changed from True to False. This affects how tensor dimensions are interpreted in the chunk_retention function. When head_first is False, the function expects input tensors to have shape [B, T, H, K] rather than [B, H, T, K].

This is part of a larger change across the codebase standardizing on head_first=False as the default format.

This change may impact existing code that relies on the default value. Ensure that all callers either explicitly set this parameter or are updated to work with the new tensor dimension ordering.

initial_state: torch.Tensor,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Default value changed from True to False

The default value of head_first parameter for fused_recurrent_delta_rule_bwd has been changed from True to False. This is a breaking change for code that relies on the default behavior without explicitly specifying this parameter.


🏁 Script executed:

#!/bin/bash
# Identify potential usage of fused_recurrent_delta_rule_bwd without explicitly specifying head_first
rg "fused_recurrent_delta_rule_bwd\s*\(" --type py | grep -v "head_first\s*="

Length of output: 253


Action Required: Explicitly Specify head_first in fused_recurrent_delta_rule_bwd Calls

The verification confirms that the call sites for fused_recurrent_delta_rule_bwd (e.g. in fla/ops/delta_rule/fused_recurrent.py) are not providing an explicit value for head_first. Since the default has changed from True to False, this will alter behavior in these cases, potentially breaking backward compatibility.

  • Location: fla/ops/delta_rule/fused_recurrent.py (definition at line 362 and other call sites)
  • Issue: Call sites of fused_recurrent_delta_rule_bwd rely on the old default behavior (True), but now default to False.
  • Recommendation: Update all calls to explicitly provide the head_first parameter to match the intended behavior (e.g., append head_first=True if the previous behavior is desired).

output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = True,
head_first: bool = False,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Default value changed from True to False

The default value of head_first parameter for fused_recurrent_delta_rule has been changed from True to False. This is a breaking change for code that relies on the default behavior without explicitly specifying this parameter.


🏁 Script executed:

#!/bin/bash
# Identify potential usage of fused_recurrent_delta_rule without explicitly specifying head_first
rg "fused_recurrent_delta_rule\s*\(" --type py | grep -v "head_first\s*="

Length of output: 830


Critical: Adjust Default Value for head_first in fused_recurrent_delta_rule

The default for the head_first parameter in fused_recurrent_delta_rule has been changed from True to False (line 504 in fla/ops/delta_rule/fused_recurrent.py). Our search shows that multiple calls—specifically in tests/ops/test_delta.py (and one in fla/layers/delta_net.py)—invoke this function without explicitly setting the head_first parameter, which means they unintentionally rely on its default value.

Action Required:

  • Review & Update Tests: Ensure that test cases invoking fused_recurrent_delta_rule explicitly pass the expected head_first value or adjust their expectations based on the new default.
  • Backward Compatibility: If the change was not intended, consider reverting the default or providing a migration plan to mitigate breaking changes for existing users.

reverse: bool = False,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Default parameter value change in multiple function signatures.

The default value of head_first has been changed from True to False in several functions: fused_recurrent_rwkv6_fwd, fused_recurrent_rwkv6_bwd, FusedRecurrentRWKV6Function.forward, and fused_recurrent_rwkv6. This is a breaking change that affects the expected shapes of input and output tensors.

Check for existing usages that might rely on the previous default behavior:

Also applies to: 448-448, 561-561, 617-617


🏁 Script executed:

#!/bin/bash
# Search for calls to these functions without explicit head_first
rg "fused_recurrent_rwkv6(_fwd|_bwd)?\([^)]*\)" -A 5 | grep -v "head_first\s*="

Length of output: 984


Action Required: Address Breaking Change in head_first Default Value

The default value of head_first was intentionally changed from True to False in multiple functions (fused_recurrent_rwkv6_fwd, fused_recurrent_rwkv6_bwd, FusedRecurrentRWKV6Function.forward, and fused_recurrent_rwkv6). This change directly impacts the expected tensor shapes, particularly for call sites that do not explicitly specify the head_first parameter. For example, the test file tests/ops/test_rwkv6.py calls fused_recurrent_rwkv6 without providing a value for head_first, meaning those tests were implicitly relying on the previous default behavior.

Please review all such usages to:

  • Update test cases or any other callers to explicitly set head_first if the original behavior is required.
  • Verify that the downstream tensor shapes align with the intended new behavior when head_first is False.

Affected lines in fla/ops/rwkv6/fused_recurrent.py:

  • 397
  • 448
  • 561
  • 617

initial_state_bias: torch.Tensor = None,
offsets: Optional[torch.LongTensor] = None,
head_first: bool = True
head_first: bool = False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Default parameter value change in multiple function signatures.

The default value of head_first has been changed from True to False in several functions: fused_chunk_ttt_linear_bwd_h, fused_chunk_ttt_linear_bwd_dh, fused_chunk_ttt_linear_fwd, fused_chunk_ttt_linear_bwd, and fused_chunk_ttt_linear. This is a breaking change that affects the expected shapes of input and output tensors.

Check for existing usages that might rely on the previous default behavior:

Also applies to: 550-550, 621-621, 680-680, 820-820


🏁 Script executed:

#!/bin/bash
# Search for calls to these functions without explicit head_first
rg "fused_chunk_ttt_linear(_bwd|_fwd|_bwd_h|_bwd_dh)?\([^)]*\)" -A 5 | grep -v "head_first\s*="

Length of output: 615


Breaking Change Detected: Update All Call Sites for head_first Parameter

The default value for head_first has been changed from True to False in multiple functions (fused_chunk_ttt_linear, fused_chunk_ttt_linear_bwd_h, fused_chunk_ttt_linear_bwd_dh, fused_chunk_ttt_linear_fwd, and fused_chunk_ttt_linear_bwd). Our verification revealed that in at least one instance—in benchmarks/ops/benchmark_ttt.py—the function is invoked without explicitly setting head_first, which means it will now default to False instead of True. This modification can lead to unexpected tensor shape manipulations and affect downstream computations.

  • Affected call:
    In benchmarks/ops/benchmark_ttt.py, the function call

    o4, _, _ = fused_chunk_ttt_linear(q, k, v, w, b, eta, chunk_size=16)

    does not specify head_first, which now defaults to False.

  • Action Required:
    Please review and update the affected call sites (and any similar usages) to pass an explicit value for head_first if the intended behavior was to use the previous default True. This includes checking related call sites in function invocations to ensure consistent tensor shaping.

Comment thread fla/ops/utils/cumsum.py
@yzhangcs yzhangcs merged commit f67688a into main Apr 7, 2025
2 of 6 checks passed
@yzhangcs yzhangcs deleted the varlen branch April 7, 2025 11:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant