
[Deprecated] head_first option removed for gla variants #334

Merged
yzhangcs merged 3 commits into main from head_first on Apr 8, 2025

Conversation

@yzhangcs
Member

@yzhangcs yzhangcs commented Apr 8, 2025

Summary by CodeRabbit

  • Documentation

    • Clarified tensor shape descriptions and updated error messages for improved user guidance.
  • Refactor

    • Streamlined function interfaces by removing deprecated parameters and redundant conditional logic.
    • Standardized tensor format handling and adjusted default parameter values for consistent behavior.
  • User Warnings

    • Introduced deprecation warnings to inform users about legacy parameter usage and encourage updated practices.
  • Tests

    • Updated test configurations—including data types and tensor dimensions—to ensure enhanced reliability and consistency.

@coderabbitai
Contributor

coderabbitai bot commented Apr 8, 2025

Walkthrough

This PR makes extensive updates across multiple modules to clarify tensor shapes and remove deprecated parameters. The documentation for the head_first flag has been revised across functions, and many signatures have been simplified by removing the head_first (and associated indices or HEAD_FIRST) parameter. Consequently, tensor rearrangement logic, conditional branches, and warnings have been consolidated for consistent behavior. Additionally, tests now use torch.float16 instead of torch.bfloat16 and adjust index ranges and shape configurations accordingly.

Changes

File(s) Change Summary
fla/ops/abc/chunk.py, fla/ops/gsa/chunk.py, fla/ops/gsa/fused_recurrent.py Updated documentation clarifying tensor shape conditions based on the (reversed) head_first flag.
fla/ops/common/chunk_h.py, fla/ops/common/chunk_h_parallel.py, fla/ops/common/chunk_o.py, fla/ops/common/fused_recurrent.py, fla/ops/common/chunk_h_split.py Removed head_first/HEAD_FIRST and indices parameters with associated conditional logic; simplified pointer arithmetic and kernel launch parameters for a unified tensor handling.
fla/ops/gated_delta_rule/chunk.py, fla/ops/retention/chunk.py, fla/ops/retention/fused_recurrent.py, fla/ops/simple_gla/chunk.py, fla/ops/simple_gla/fused_recurrent.py Introduced deprecation warnings and consolidated tensor rearrangement operations to ensure consistent input/output formats when handling the deprecated head_first flag.
fla/ops/gla/naive.py, fla/ops/nsa/naive.py Adjusted tensor transposition and lambda rearrangements before type conversion to support flexible dimensions and correct ordering.
fla/ops/utils/cumsum.py Removed external indices parameter; integrated computation of indices using prepare_chunk_indices internally.
tests/ops/test_cumsum.py, tests/ops/test_gla.py, tests/ops/test_retention.py, tests/ops/test_simple_gla.py Updated test parameterizations: replaced torch.bfloat16 with torch.float16, adjusted dimension ordering, range values, and variable names for consistency.

Sequence Diagram(s)

sequenceDiagram
    participant U as User Code
    participant F as Chunk/Kernel Function
    participant K as Kernel/Lower-Level Operation

    U->>F: Call function with tensors (optional head_first flag)
    alt head_first is provided and True
        F->>F: Issue deprecation warning
        F->>F: Rearrange tensor from head-first to time-first format
    else head_first is False
        F->>F: Proceed with input tensors as provided
    end
    F->>K: Invoke kernel with standardized tensor pointers and parameters
    K-->>F: Return computed tensor result
    alt head_first was True
        F->>F: Rearrange output back to head-first format if needed
    end
    F->>U: Return final output tensor
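The compatibility shim in the diagram above boils down to an index permutation between head-first [B, H, T, D] and time-first [B, T, H, D] layouts. The library does this with einops' rearrange; the pure-Python sketch below (illustrative names, flat lists standing in for tensors) shows the same offset math:

```python
def head_first_to_time_first(flat, B, H, T, D):
    """Return a flat list laid out as [B, T, H, D] given [B, H, T, D] input."""
    out = [0.0] * (B * T * H * D)
    for b in range(B):
        for h in range(H):
            for t in range(T):
                for d in range(D):
                    src = ((b * H + h) * T + t) * D + d   # [B, H, T, D] offset
                    dst = ((b * T + t) * H + h) * D + d   # [B, T, H, D] offset
                    out[dst] = flat[src]
    return out

# B=1, H=2, T=3, D=1: head h0 holds [0, 1, 2], head h1 holds [3, 4, 5]
print(head_first_to_time_first([0, 1, 2, 3, 4, 5], 1, 2, 3, 1))
# → [0, 3, 1, 4, 2, 5] (time steps interleaved across heads)
```

This is exactly what `rearrange(x, 'b h t d -> b t h d')` computes, just spelled out element by element.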

Poem

I'm a rabbit in a code-filled glen,
Hopping through refactors now and then.
Warnings chime and docs align,
With shapes and flags all set just fine.
Simplified flows make my heart sing,
In this code garden, I flourish like spring!
🐰💻


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (12)
fla/ops/retention/fused_recurrent.py (2)

25-28: Added deprecation warning for head_first parameter.

Good practice to warn users about the upcoming removal of the deprecated parameter. However, the warning should include a stacklevel parameter to ensure it points to the caller's code instead of the warning implementation.

-        warnings.warn(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead.",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

25-25: No explicit stacklevel keyword argument found

(B028)
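For reference, `stacklevel=2` makes Python attribute the warning to the caller's line rather than to the `warnings.warn` call inside the library. A stdlib-only demonstration (function names are illustrative, not from the codebase):

```python
import warnings

def deprecated_api():
    # stacklevel=2 attributes the warning to whoever called deprecated_api()
    warnings.warn("head_first is deprecated", DeprecationWarning, stacklevel=2)

def user_code():
    deprecated_api()  # with stacklevel=2, the warning points at THIS line

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    user_code()

# caught[0].filename / caught[0].lineno now reference the call site in
# user_code(), not the warnings.warn line inside deprecated_api()
print(len(caught), caught[0].category.__name__)
```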


30-36: Added defensive warning for potential tensor format mismatches.

Good defensive programming to detect when users might have passed tensors in the wrong format. As with the previous warning, this should include a stacklevel parameter.

-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

31-31: No explicit stacklevel keyword argument found

(B028)
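The heuristic behind this warning is simply that, in time-first layout, the sequence dimension is normally much larger than the head count, so `T < H` hints that the axes may be swapped. A stdlib sketch of the check (shapes as plain tuples; the helper name is illustrative):

```python
import warnings

def warn_if_maybe_head_first(shape):
    """shape is expected as (B, T, H, D); T < H suggests the caller may
    have passed head-first tensors (B, H, T, D) by mistake."""
    B, T, H, D = shape
    if T < H:
        warnings.warn(
            f"potential format mismatch: seq_len ({T}) < num_heads ({H})",
            stacklevel=2,
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_maybe_head_first((2, 4, 16, 64))     # T=4 < H=16: suspicious
    warn_if_maybe_head_first((2, 1024, 16, 64))  # typical time-first: silent
print(len(caught))  # → 1
```

Note this is only a heuristic: a genuinely short sequence with many heads would trigger a false positive, which is why the library warns rather than raises.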

fla/ops/retention/chunk.py (2)

55-60: Good use of deprecation warning for the head_first parameter.

The warning correctly informs users that head_first is being deprecated and provides guidance on what to use instead.

Consider adding the stacklevel parameter to the warning to help users identify the source of the warning in their code:

warnings.warn(
    "head_first is deprecated and will be removed in a future version. "
    "Please use head_first=False for now instead.",
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

56-56: No explicit stacklevel keyword argument found

(B028)


61-67: Helpful format mismatch detection.

This check provides a useful warning when tensor shapes suggest potential incorrect usage, helping users debug issues related to tensor format.

Consider adding the stacklevel parameter to the warning to help users identify the source of the warning in their code:

warnings.warn(
    f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
    "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
    "when head_first=False was specified. "
    "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

62-62: No explicit stacklevel keyword argument found

(B028)

fla/ops/simple_gla/fused_recurrent.py (2)

90-95: Added deprecation warning for head_first parameter.

The warning correctly informs users that head_first is being deprecated and provides guidance on using head_first=False.

Consider adding the stacklevel parameter to the warning to help users identify the source of the warning in their code:

warnings.warn(
    "head_first is deprecated and will be removed in a future version. "
    "Please use head_first=False for now instead.",
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

91-91: No explicit stacklevel keyword argument found

(B028)


96-102: Added input format validation.

This check helps users identify potential issues when their tensor shapes don't match the expected format, improving developer experience.

Consider adding the stacklevel parameter to the warning to help users identify the source of the warning in their code:

warnings.warn(
    f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
    "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
    "when head_first=False was specified. "
    "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

97-97: No explicit stacklevel keyword argument found

(B028)

fla/ops/gla/fused_recurrent.py (2)

92-101: Comprehensive handling of head_first deprecation.

The code properly handles all tensors that need to be rearranged, including the conditional rearrangement of gk and gv if they are provided. The warning message is clear and informative.

Consider adding the stacklevel parameter to the warning to help users identify the source of the warning in their code:

warnings.warn(
    "head_first is deprecated and will be removed in a future version. "
    "Please use head_first=False for now instead.",
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

93-93: No explicit stacklevel keyword argument found

(B028)


102-108: Added input format validation.

The warning about potential format mismatch helps users identify and debug issues when tensor shapes don't match the expected format.

Consider adding the stacklevel parameter to the warning to help users identify the source of the warning in their code:

warnings.warn(
    f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
    "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
    "when head_first=False was specified. "
    "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

103-103: No explicit stacklevel keyword argument found

(B028)

fla/ops/simple_gla/chunk.py (2)

252-256: Added warning for deprecated head_first parameter.

A proper warning message has been added to inform users that the head_first parameter is deprecated and will be removed in a future version. This is good practice for maintaining backward compatibility while guiding users toward the new API.

Consider adding a specific stacklevel to the warning to properly attribute it to the caller:

-        warnings.warn(
+        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
+            , stacklevel=2
        )
🧰 Tools
🪛 Ruff (0.8.2)

253-253: No explicit stacklevel keyword argument found

(B028)


258-264: Added warning for potential format mismatch.

A warning has been added to alert users when the input tensor shape suggests a potential format mismatch. This helps users identify and fix issues with their input data.

Consider adding a specific stacklevel to the warning to properly attribute it to the caller:

-        warnings.warn(
+        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+            , stacklevel=2
        )
🧰 Tools
🪛 Ruff (0.8.2)

259-259: No explicit stacklevel keyword argument found

(B028)

fla/ops/gla/chunk.py (2)

1421-1433: Well-implemented deprecation warning

Good implementation of deprecation warning with clear guidance for users. The warning explains that head_first is deprecated and provides instruction on how to transition to the new format.

However, the warning is missing the stacklevel parameter which helps attribute the warning to the correct call site in user code.

Consider adding the stacklevel parameter:

warnings.warn(
    "head_first is deprecated and will be removed in a future version. "
    "Please use head_first=False for now instead.",
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

1422-1422: No explicit stacklevel keyword argument found

(B028)




1428-1433: Helpful shape validation warning

The shape validation warning provides useful diagnostics to help users identify when they might be using tensors in the wrong format. This is particularly valuable during the transition period.

Similar to the previous warning, add the stacklevel parameter:

warnings.warn(
    f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
    "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
    "when head_first=False was specified. "
    "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

1428-1428: No explicit stacklevel keyword argument found

(B028)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c060965 and bb6ae12.

📒 Files selected for processing (21)
  • fla/ops/abc/chunk.py (1 hunks)
  • fla/ops/common/chunk_h.py (4 hunks)
  • fla/ops/common/chunk_h_parallel.py (7 hunks)
  • fla/ops/common/chunk_o.py (9 hunks)
  • fla/ops/common/fused_recurrent.py (11 hunks)
  • fla/ops/gated_delta_rule/chunk.py (3 hunks)
  • fla/ops/gla/chunk.py (26 hunks)
  • fla/ops/gla/fused_recurrent.py (3 hunks)
  • fla/ops/gla/naive.py (2 hunks)
  • fla/ops/gsa/chunk.py (4 hunks)
  • fla/ops/gsa/fused_recurrent.py (1 hunks)
  • fla/ops/nsa/naive.py (3 hunks)
  • fla/ops/retention/chunk.py (2 hunks)
  • fla/ops/retention/fused_recurrent.py (3 hunks)
  • fla/ops/simple_gla/chunk.py (7 hunks)
  • fla/ops/simple_gla/fused_recurrent.py (3 hunks)
  • fla/ops/utils/cumsum.py (4 hunks)
  • tests/ops/test_cumsum.py (4 hunks)
  • tests/ops/test_gla.py (5 hunks)
  • tests/ops/test_retention.py (3 hunks)
  • tests/ops/test_simple_gla.py (7 hunks)
🧰 Additional context used
🧬 Code Definitions (8)
tests/ops/test_gla.py (2)
fla/ops/gla/chunk.py (1)
  • chunk_gla (1348-1463)
fla/ops/gla/fused_recurrent.py (1)
  • fused_recurrent_gla (13-141)
fla/ops/retention/fused_recurrent.py (1)
fla/ops/simple_gla/fused_recurrent.py (1)
  • fused_recurrent_simple_gla (13-133)
fla/ops/retention/chunk.py (1)
fla/ops/simple_gla/chunk.py (1)
  • chunk_simple_gla (179-294)
fla/ops/common/chunk_h.py (1)
fla/ops/common/utils.py (1)
  • prepare_chunk_offsets (65-69)
fla/ops/common/fused_recurrent.py (1)
fla/ops/utils/cumsum.py (1)
  • chunk_global_cumsum (362-381)
fla/ops/common/chunk_h_parallel.py (1)
fla/ops/common/utils.py (2)
  • prepare_chunk_indices (56-61)
  • prepare_chunk_offsets (65-69)
fla/ops/utils/cumsum.py (1)
fla/ops/common/utils.py (1)
  • prepare_chunk_indices (56-61)
fla/ops/common/chunk_o.py (1)
fla/ops/common/utils.py (1)
  • prepare_chunk_indices (56-61)
🪛 Ruff (0.8.2)
fla/ops/retention/fused_recurrent.py

25-25: No explicit stacklevel keyword argument found

(B028)


31-31: No explicit stacklevel keyword argument found

(B028)

fla/ops/retention/chunk.py

56-56: No explicit stacklevel keyword argument found

(B028)


62-62: No explicit stacklevel keyword argument found

(B028)

fla/ops/gla/fused_recurrent.py

93-93: No explicit stacklevel keyword argument found

(B028)


103-103: No explicit stacklevel keyword argument found

(B028)

fla/ops/simple_gla/fused_recurrent.py

91-91: No explicit stacklevel keyword argument found

(B028)


97-97: No explicit stacklevel keyword argument found

(B028)

fla/ops/simple_gla/chunk.py

253-253: No explicit stacklevel keyword argument found

(B028)


259-259: No explicit stacklevel keyword argument found

(B028)

fla/ops/gla/chunk.py

1422-1422: No explicit stacklevel keyword argument found

(B028)


1428-1428: No explicit stacklevel keyword argument found

(B028)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (107)
fla/ops/gla/naive.py (2)

21-21: Tensor dimensions standardized with transpose operation.

This change ensures that the tensors are consistently in the format [B, H, T, ...] by transposing dimensions 1 (sequence length) and 2 (heads) before conversion to float. This aligns with the PR objective of removing the head_first option.


41-41: Output tensor format standardized.

The output tensor is now transposed back to the expected format before converting to the specified data type, ensuring consistency with the input tensor format adjustment on line 21.

fla/ops/abc/chunk.py (1)

1096-1096: Documentation updated with correct tensor shape description.

The updated documentation for the s parameter correctly clarifies the expected tensor shape based on the head_first flag. This ensures users understand the proper input format.

fla/ops/gsa/fused_recurrent.py (1)

480-480: Documentation updated with correct tensor shape description.

The updated documentation for the s parameter correctly clarifies the expected tensor shape based on the head_first flag. This ensures users understand the proper input format and maintains consistency with similar documentation updates in other files.

fla/ops/retention/fused_recurrent.py (4)

4-4: Added necessary imports for deprecation handling.

The added imports for warnings and rearrange from einops support the new functionality for deprecation warnings and tensor format conversion.

Also applies to: 8-8


29-29: Added tensor format conversion for backward compatibility.

This change ensures that tensors are converted from head-first format to sequence-first format when head_first=True, allowing internal processing to be standardized.


37-38: Simplified tensor shape calculation and expansion.

The calculation of s now directly uses q.shape[2] for the range, and g is expanded consistently regardless of the head_first parameter value. This simplifies the code and removes conditional logic.


50-52: Added output format conversion for backward compatibility.

This change ensures that if head_first=True, the output tensor is converted back to head-first format before returning, maintaining backward compatibility with existing code.

fla/ops/retention/chunk.py (4)

4-4: Import requirements properly added for the new functionality.

The added imports for warnings and rearrange from einops are necessary to support the new deprecation warnings and tensor reshaping features.

Also applies to: 8-8


68-69: Simplified tensor shape handling.

The computation of s and expansion of g have been simplified by directly using the shape of q tensor, making the code more maintainable.


70-79: Properly removed head_first parameter from function call.

Successfully removed the head_first parameter from the chunk_simple_gla call, which is consistent with the PR objective of deprecating this parameter.


80-82: Format conversion for backward compatibility.

This code correctly handles the backward compatibility by rearranging the output tensor back to head-first format when the parameter is set to True.

tests/ops/test_gla.py (5)

52-55: Updated tensor shapes to use time-first format.

The tensor dimensions have been changed from (B, H, T, D) (head-first) to (B, T, H, D) (time-first) format, aligning with the PR objective to standardize on time-first format.


139-147: Improved readability with multi-line formatting.

The function call to chunk_gla has been reformatted to use multi-line arguments, which improves code readability while maintaining the same functionality.


155-172: Consistently formatted function calls.

The function calls to fused_recurrent_gla are now consistently formatted with one argument per line, improving readability and maintainability.


208-211: Updated variable name and generation logic.

The variable offsets has been renamed to cu_seqlens to match the parameter name in the called functions, and the logic for generating its values has been updated to ensure a minimum sequence length of 16.


225-226: Consistent parameter naming.

The offsets parameter has been renamed to cu_seqlens in all function calls, maintaining consistency with the parameter names in the function definitions.

Also applies to: 231-232, 248-249
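The `offsets` → `cu_seqlens` rename follows the widespread convention (e.g. in flash-attention-style APIs) of passing cumulative sequence lengths that delimit a packed variable-length batch. A minimal sketch of how such an array is built:

```python
def build_cu_seqlens(seqlens):
    # cu_seqlens[i] is where sequence i starts in the packed tensor;
    # cu_seqlens[-1] is the total token count across the batch
    cu = [0]
    for n in seqlens:
        cu.append(cu[-1] + n)
    return cu

print(build_cu_seqlens([3, 5, 2]))  # → [0, 3, 8, 10]
```

Sequence i then occupies rows `cu_seqlens[i]:cu_seqlens[i+1]` of the packed tensor, which is why a minimum per-sequence length (here 16) matters for chunked kernels.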

fla/ops/simple_gla/fused_recurrent.py (3)

4-4: Added required imports for new functionality.

The imports for warnings and rearrange from einops have been added to support the deprecation warnings and tensor reshaping functionality.

Also applies to: 8-8


120-130: Removed head_first parameter from function call.

The head_first parameter has been properly removed from the fused_recurrent function call, aligning with the PR objective to deprecate this parameter.


131-133: Format conversion for backward compatibility.

The code correctly handles backward compatibility by rearranging the output tensor back to head-first format when the parameter is set to True.

fla/ops/gla/fused_recurrent.py (2)

4-4: Added required imports for new functionality.

The imports for warnings and rearrange from einops have been added to support the deprecation warnings and tensor reshaping functionality.

Also applies to: 8-8


139-141: Format conversion for backward compatibility.

The code correctly handles backward compatibility by rearranging the output tensor back to head-first format when the parameter is set to True.

fla/ops/gated_delta_rule/chunk.py (3)

351-351: Documentation improved

The assertion message now provides clearer guidance about the expected shape of the beta tensor based on the head_first parameter.


369-369: Enhanced tensor rearrangement flexibility

The updated rearrangement pattern using ellipsis (...) allows handling tensors with varying trailing dimensions more elegantly, making the code more robust to different tensor shapes.


386-386: Consistent rearrangement pattern

This modification matches the input tensor rearrangement approach (using ellipsis), maintaining consistency in how tensors are handled throughout the function.

tests/ops/test_simple_gla.py (5)

40-40: Parameter renamed for better clarity

The parameter name change from BT to chunk_size improves code readability by using a more descriptive name that better represents the parameter's purpose.


53-53: Maintained compatibility with variable name

Good approach to assign the new parameter to the existing variable name, preserving the internal implementation while improving the API.


63-63: Improved variable naming

Changed variable name from l to t better represents that this dimension corresponds to sequence length or time dimension, improving code readability.


66-66: Updated parameter reference

Updated to use the new parameter name chunk_size for consistency with the function signature.


73-73: Consistent variable and parameter naming

Loop condition updated to use the renamed variable t and parameter chunk_size, maintaining logical consistency.

fla/ops/nsa/naive.py (3)

30-30: Fixed tensor shape documentation

Corrected the documentation for the indices parameter shape based on the head_first parameter, making it more accurate and consistent with the actual implementation.


55-55: More flexible tensor rearrangement

Updated to use ellipsis (...) in the rearrangement pattern, allowing the function to handle tensors with varying trailing dimensions more flexibly.


95-95: Consistent tensor handling

The output tensor rearrangement now matches the input handling style, creating consistency throughout the function.

tests/ops/test_retention.py (2)

33-33: Changed test dtype from bfloat16 to float16

Similar to the changes in test_simple_gla.py, the test dtype has been standardized to torch.float16 instead of torch.bfloat16.

This appears to be a consistent change across test files. Is there a particular reason for switching from bfloat16 to float16? This might affect numerical precision in some edge cases.

Also applies to: 87-87
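For context on the dtype question raised above: float16 and bfloat16 are both 16 bits wide but split those bits differently, so the switch trades dynamic range for precision. The arithmetic is simple enough to show directly:

```python
# Both formats are 16 bits wide but allocate them differently:
#   float16:  1 sign + 5 exponent + 10 mantissa bits
#   bfloat16: 1 sign + 8 exponent +  7 mantissa bits
fp16_eps = 2.0 ** -10    # relative spacing near 1.0
bf16_eps = 2.0 ** -7
fp16_max = (2 - 2.0 ** -10) * 2.0 ** 15   # largest finite float16

# float16 is ~8x more precise near 1.0 (so test tolerances can be tighter),
# while bfloat16 keeps float32's exponent range and is harder to overflow
print(fp16_eps, bf16_eps, fp16_max)  # fp16_max → 65504.0
```

So the switch plausibly tightens achievable test tolerances, at the cost of overflow risk for intermediate values above ~65504, which the reviewer's edge-case question is getting at.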


107-107: Modified random index generation range

The test now uses a broader exclusion range for random permutation (T-16 instead of T-1), which changes how random segments are distributed during testing.

This change might affect the distribution of test cases. Can you ensure this won't reduce test coverage or create edge cases that aren't being tested properly?

tests/ops/test_cumsum.py (4)

55-55: Updated the test dtype from bfloat16 to float16.

The test parameterization has been changed to use torch.float16 instead of torch.bfloat16. This change aligns with similar updates in other test files to standardize the data types used in testing.


75-75: Updated the test dtype from bfloat16 to float16.

Consistent with the previous change, this test now uses torch.float16 instead of torch.bfloat16, standardizing dtype usage across tests.


97-97: Updated the test dtype from bfloat16 to float16.

The test parameterization has been changed to use torch.float16 instead of torch.bfloat16, maintaining consistency across all test functions in this file.


121-121: Updated the test dtype from bfloat16 to float16.

The final test function's dtype parameterization has been changed to torch.float16, completing the standardization of dtype usage throughout the test file.
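The chunked cumsum kernels under test compute a global cumulative sum as per-chunk local sums plus a carried prefix. In pure Python the idea looks like this (a sketch of the algorithm, not the Triton kernel):

```python
def chunked_cumsum(x, chunk_size):
    # within each chunk compute a local running sum, then add the carry
    # (total of all previous chunks) to recover the global cumulative sum
    out, carry = [], 0.0
    for start in range(0, len(x), chunk_size):
        local = 0.0
        for v in x[start:start + chunk_size]:
            local += v
            out.append(carry + local)
        carry += local
    return out

print(chunked_cumsum([1.0, 2.0, 3.0, 4.0, 5.0], chunk_size=2))
# → [1.0, 3.0, 6.0, 10.0, 15.0], identical to a plain running sum
```

Chunks are independent given their carry, which is what lets the real kernel parallelize across chunks.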

fla/ops/common/chunk_h.py (10)

309-309: Simplified tensor shape unpacking.

The code now directly unpacks the shape from k and v tensors, removing the conditional logic that previously depended on the head_first parameter. This simplifies the code and makes tensor shape handling more consistent.


315-316: Streamlined variable assignment for equal-length sequences.

The assignment of N, NS, and split_offsets variables has been simplified when processing equal-length sequences, improving code readability.


318-318: Simplified variable assignment for variable-length sequences.

The assignment of N and NS variables has been streamlined when processing variable-length sequences, making the code more concise.


320-320: Modified tensor initialization shape.

The tensor h initialization has been updated to use a more consistent shape format, removing conditional logic that previously depended on the head_first parameter.


343-343: Hardcoded HEAD_FIRST parameter to False.

The HEAD_FIRST parameter is now explicitly set to False rather than being dynamically determined, reflecting the removal of the head_first option from the function signature. This simplifies the code and enforces a consistent tensor layout.


364-365: Simplified tensor shape unpacking and HQ assignment.

The code now directly unpacks dimensions from tensor shapes and assigns HQ from q.shape[2], removing conditional logic that previously depended on the head_first parameter.


372-373: Streamlined variable assignment for equal-length sequences.

The assignment of N, NS, and split_offsets variables has been simplified for the backward pass when processing equal-length sequences.


375-375: Simplified variable assignment for variable-length sequences.

The assignment of N and NS variables has been streamlined for the backward pass when processing variable-length sequences.


378-378: Modified tensor initialization shape for backward pass.

The tensor dh initialization has been updated to use a more consistent shape format for the backward pass, removing conditional logic that previously depended on the head_first parameter.


405-405: Hardcoded HEAD_FIRST parameter to False in backward pass.

The HEAD_FIRST parameter is now explicitly set to False in the backward pass kernel, ensuring consistent tensor layout handling throughout forward and backward operations.

fla/ops/gsa/chunk.py (4)

9-9: Added einops import for tensor reshaping.

The import of rearrange and reduce functions from einops facilitates tensor reshaping operations used later in the code, particularly for handling different tensor layouts.


1153-1153: Changed the default value of head_first parameter from True to False.

The default value for the head_first parameter has been changed from True to False, reflecting the deprecation of the head_first option as indicated in the PR title. This aligns with the broader effort to standardize tensor shapes across the codebase.


1165-1165: Updated documentation to reflect the correct tensor shape.

The documentation for the s parameter has been updated to correctly describe the expected tensor shape when head_first=False. This ensures that users understand the expected input format after the parameter deprecation.


1271-1272: Added conditional tensor reshaping to maintain backward compatibility.

A conditional block has been added to rearrange the output tensor o when head_first is True. This ensures backward compatibility while standardizing on the new tensor layout, allowing existing code that expects the head-first format to continue working.

fla/ops/common/chunk_h_parallel.py (11)

14-14: Updated imports to include prepare_chunk_indices function.

The import statement has been updated to include prepare_chunk_indices from the utils module, which is now used to compute indices for chunking instead of relying on the previously passed indices parameter.


493-493: Simplified tensor shape unpacking.

The code now directly unpacks the shape from k and v tensors, removing the conditional logic that previously depended on the head_first parameter. This simplifies the code and makes tensor shape handling more consistent.


499-500: Added automatic chunk indices computation.

Instead of requiring indices as a parameter, the function now computes them internally using prepare_chunk_indices when dealing with variable-length sequences. This simplifies the API while maintaining the same functionality.


502-502: Modified tensor initialization shape.

The tensor h initialization has been updated to use a more consistent shape format, matching the changes in chunk_h.py and removing the dependency on the head_first parameter.


524-524: Hardcoded HEAD_FIRST parameter to False.

The HEAD_FIRST parameter is now explicitly set to False rather than being dynamically determined, which simplifies the code and enforces a consistent tensor layout.


545-545: Hardcoded HEAD_FIRST parameter to False in reduction kernel.

The HEAD_FIRST parameter is now explicitly set to False in the reduction kernel as well, ensuring consistent tensor layout handling throughout all operations.


566-567: Simplified tensor shape unpacking for backward pass.

The code now directly unpacks dimensions from tensor shapes for the backward pass, removing conditional logic that previously depended on the head_first parameter.


574-575: Added automatic chunk indices computation for backward pass.

Similar to the forward pass, the backward pass now computes indices internally using prepare_chunk_indices when dealing with variable-length sequences, simplifying the API.


578-578: Modified tensor initialization shape for backward pass.

The tensor dh initialization has been updated to use a more consistent shape format for the backward pass, matching the changes in the forward pass and in chunk_h.py.


604-604: Hardcoded HEAD_FIRST parameter to False in backward parallel kernel.

The HEAD_FIRST parameter is now explicitly set to False in the backward parallel kernel, maintaining consistency with the forward pass changes.


628-628: Hardcoded HEAD_FIRST parameter to False in backward reduction kernel.

The HEAD_FIRST parameter is now explicitly set to False in the backward reduction kernel as well, completing the consistent application of this change throughout all kernel functions.

fla/ops/common/fused_recurrent.py (13)

53-54: Removed HEAD_FIRST parameter from kernel function signature.

The HEAD_FIRST parameter has been removed from the function signature, simplifying the interface and standardizing on a specific tensor layout. This makes the code cleaner and easier to maintain.


66-75: Improved memory access pattern by simplifying pointer calculations.

The conditional branching based on HEAD_FIRST has been removed, and the pointer calculations are now consistently based on the bos variable and the REVERSE flag. This change simplifies the logic and likely improves performance by reducing branches in the kernel.


103-112: Simplified pointer updates in the loop.

The pointer update calculations have been streamlined to only depend on the REVERSE flag, which makes the code more maintainable and easier to reason about.


175-184: Consistent pointer initialization in backward kernel.

Similar to the forward kernel, the backward kernel now has simplified pointer calculations that only depend on the bos variable and the REVERSE flag, making the code more consistent.


213-222: Simplified pointer updates in backward pass.

The pointer update logic in the backward pass is now cleaner and consistent with the changes in the forward pass.


227-238: Simplified pointer reinitialization for second phase of backward pass.

The pointer calculations for the second phase of the backward pass have been standardized to follow the same pattern as the rest of the code.


265-276: Consistent pointer update logic in second backward pass loop.

The pointer update logic in the second backward pass loop is now simpler and consistent with the rest of the code.


294-297: Removed head_first parameter from function signature.

The head_first parameter has been removed from the fused_recurrent_fwd function signature, aligning with the broader refactoring to standardize on a specific tensor layout.


302-303: Simplified final state tensor shape declaration.

The shape of the final state tensor ht is now defined directly using N, H, K, V dimensions, making the code more straightforward.


347-349: Removed head_first parameter from backward function signature.

The head_first parameter has been removed from the fused_recurrent_bwd function signature, consistent with the changes to the forward function.


395-399: Updated chunk_global_cumsum calls to match new API.

The calls to chunk_global_cumsum have been updated to remove the head_first parameter, aligning with the changes in that function's signature.


421-422: Removed head_first parameter from FusedRecurrentFunction.forward.

The head_first parameter has been removed from the autograd function's forward method, consistent with the other function signature changes.


471-472: Updated return statement in backward method.

The return statement in the backward method has been updated to match the changes in the function signatures.

fla/ops/simple_gla/chunk.py (9)

4-4: Added warnings import for deprecation notices.

The warnings module has been imported to support the deprecation warnings added for the head_first parameter.


9-9: Added einops import for tensor rearrangement.

The rearrange function from einops is now imported to handle tensor rearrangement when head_first=True is used.


28-28: Simplified chunk_local_cumsum call by removing head_first and indices parameters.

The call to chunk_local_cumsum now omits the head_first and indices parameters, relying on the internal computation of indices.


133-136: Simplified shape extraction in forward method.

The tensor shape extraction in the forward method is now simpler, directly using q.shape[1] to determine T.


159-160: Updated context saving in backward method.

The context now only saves chunk_size, scale, and offsets, removing the no longer needed head_first and indices parameters.


173-174: Simplified chunk_local_cumsum call in backward method.

The call to chunk_local_cumsum in the backward method has been updated to align with the new function signature.


175-175: Simplified return statement in backward method.

The return statement in the backward method has been updated to match the changes in function signatures.


257-257: Added tensor rearrangement for backward compatibility.

When head_first=True is used, the tensors are automatically rearranged to the standard format expected by the implementation. This ensures backward compatibility during the deprecation period.


292-294: Added tensor rearrangement for output when head_first=True.

When head_first=True is used, the output tensor is rearranged back to the head-first format for backward compatibility.

fla/ops/utils/cumsum.py (5)

10-10: Added import for prepare_chunk_indices function.

The prepare_chunk_indices function is now imported from fla.ops.common.utils, allowing the indices to be computed internally rather than passed as parameters.


243-244: Compute indices internally using prepare_chunk_indices.

Instead of receiving indices as a parameter, the function now computes it internally using prepare_chunk_indices when offsets is provided. This is a good design choice for encapsulation.


274-275: Added internal computation of indices in chunk_local_cumsum_vector.

Similar to chunk_local_cumsum_scalar, the chunk_local_cumsum_vector function now computes indices internally using prepare_chunk_indices.


391-393: Updated function signature with output_dtype and kwargs.

The chunk_local_cumsum function signature has been updated to include output_dtype as a named parameter and add **kwargs for backward compatibility.
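The shape of that signature can be sketched as follows. This is a minimal, hypothetical stand-in for chunk_local_cumsum (a 1-D numpy cumsum that restarts at every chunk boundary), shown only to illustrate the output_dtype keyword and the **kwargs shim that swallows removed arguments during the deprecation window; the real kernel is a Triton implementation with a different body.

```python
import warnings

import numpy as np


def chunk_local_cumsum(g, chunk_size, offsets=None, output_dtype=None, **kwargs):
    # **kwargs absorbs removed arguments (e.g. head_first, indices) so that
    # old call sites keep working while emitting a deprecation warning.
    if kwargs:
        warnings.warn(f"ignoring deprecated arguments: {sorted(kwargs)}",
                      DeprecationWarning)
    # offsets handling for varlen batches is omitted in this sketch.
    out = np.empty_like(g, dtype=output_dtype or g.dtype)
    for start in range(0, len(g), chunk_size):
        # the cumulative sum restarts at each chunk boundary
        out[start:start + chunk_size] = np.cumsum(g[start:start + chunk_size], axis=0)
    return out
```

For example, an all-ones input of length 8 with chunk_size=4 yields [1, 2, 3, 4, 1, 2, 3, 4], and passing head_first=... as a keyword triggers the warning rather than a TypeError.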


397-400: Simplified function calls by removing indices parameter.

The calls to chunk_local_cumsum_scalar and chunk_local_cumsum_vector have been updated to pass output_dtype but omit the now-internally-computed indices.

fla/ops/common/chunk_o.py (9)

10-10: Added import for prepare_chunk_indices function.

The prepare_chunk_indices function is now imported from fla.ops.common.utils, similar to the changes in other files.


467-470: Updated shape extraction and indices computation in chunk_fwd_o.

The chunk_fwd_o function now uses a simpler shape extraction approach and computes indices internally using prepare_chunk_indices.


492-493: Set HEAD_FIRST parameter to False in kernel call.

The HEAD_FIRST parameter is now consistently set to False in the kernel call, reflecting the standardization on a specific tensor layout.


507-510: Updated shape extraction and indices computation in chunk_bwd_dv.

The chunk_bwd_dv function now uses a simpler shape extraction approach and computes indices internally using prepare_chunk_indices.


541-542: Set HEAD_FIRST parameter to False in backward kernel call.

The HEAD_FIRST parameter is now consistently set to False in the backward kernel call.


556-559: Updated shape extraction and indices computation in chunk_bwd_dv_local.

The chunk_bwd_dv_local function now uses a simpler shape extraction approach and computes indices internally using prepare_chunk_indices.


588-589: Set HEAD_FIRST parameter to False in dv_local kernel call.

The HEAD_FIRST parameter is now consistently set to False in the dv_local kernel call.


608-611: Updated shape extraction and indices computation in chunk_bwd_dqkwg.

The chunk_bwd_dqkwg function now uses a simpler shape extraction approach and computes indices internally using prepare_chunk_indices.


648-649: Set HEAD_FIRST parameter to False in dqkwg kernel call.

The HEAD_FIRST parameter is now consistently set to False in the dqkwg kernel call.

fla/ops/gla/chunk.py (7)

4-4: Good addition of the warnings module

Adding the warnings module is appropriate for implementing deprecation warnings.


856-857: Improved tensor shape handling

The explicit extraction of tensor dimensions makes the code more readable and clearer about the expected input format. This is a good change as it standardizes all tensors to use the [B, T, H, K/V] format.

Also applies to: 955-956, 989-990, 1023-1024, 1057-1058, 1206-1207, 1298-1299


858-859: Centralized index preparation

Using the prepare_chunk_indices helper function consistently across the codebase simplifies the logic and reduces code duplication.

Also applies to: 957-958, 991-993, 1025-1027, 1059-1060, 1106-1107


863-864: Simplified tensor allocation

The tensor allocation now uses a consistent format without conditional branches for different tensor layouts, which makes the code easier to understand and maintain.

Also applies to: 906-907, 995-996


879-880: Hardcoded HEAD_FIRST parameter

Setting HEAD_FIRST=False consistently throughout the codebase aligns with the PR objective to deprecate and remove the head_first option.

Also applies to: 900-901, 925-926, 940-941, 977-978, 1009-1010, 1044-1045, 1085-1086, 1135-1136


1153-1154: Consistent cumsum calculation

The chunk_local_cumsum function is now used consistently with the same parameter structure, improving code readability.

Also applies to: 1209-1210


1451-1463: Improved ChunkGLAFunction.apply call and backward compatibility

The refactored function call structure is cleaner and more maintainable. The conditional rearrangement of the output tensor when head_first=True ensures backward compatibility for users who haven't migrated yet.

 @pytest.mark.parametrize("D", test_d_list)
 @pytest.mark.parametrize("gate_logit_normalizer", test_gate_list)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float16])
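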

💡 Verification agent

🧩 Analysis chain

Changed test dtype from bfloat16 to float16

All test cases now use torch.float16 instead of torch.bfloat16. This appears to be a deliberate change across multiple test functions.

Also applies to: 190-190, 266-266, 337-337


🏁 Script executed:

#!/bin/bash
# Check if there's a pattern of similar changes across the codebase
rg -A 1 "@pytest.mark.parametrize\(\"dtype\"" --glob "tests/**/*.py"

Length of output: 10380


Action Required: Clarify and Unify Test Dtype Usage in test_simple_gla.py

The change at line 118 in tests/ops/test_simple_gla.py updates the dtype parameter to use torch.float16 instead of torch.bfloat16. However, note the following based on our search:

  • Several parameterizations in this file (e.g., at lines 118, 190, 266) now use torch.float16, but there is at least one instance (around line 337) where the dtype is still set as [torch.float, torch.bfloat16].
  • Moreover, a quick scan of other test files (such as tests/test_model.py and tests/ops/test_utils.py) reveals that not all tests have switched from torch.bfloat16 to torch.float16.

Recommendations:

  • If the intent was to apply a uniform change across all test cases in test_simple_gla.py, update the remaining parameterization(s) (e.g., change [torch.float, torch.bfloat16] to the appropriate value) for consistency.
  • Consider whether this change should be applied globally across the entire test suite. If so, update other files that still use torch.bfloat16 or, if intentional, add a comment clarifying the rationale behind the mixed usage.

@zhiyuan1i zhiyuan1i force-pushed the main branch 2 times, most recently from d503477 to 17d5c21 on April 8, 2025 13:44
@yzhangcs yzhangcs merged commit 4168bb2 into main Apr 8, 2025
5 of 6 checks passed
@yzhangcs yzhangcs deleted the head_first branch April 8, 2025 13:46
@yzhangcs yzhangcs restored the head_first branch April 8, 2025 13:47
yzhangcs added a commit that referenced this pull request Apr 8, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
fla/ops/gsa/chunk.py (3)

10-10: Importing reduce
The import reduce from einops is introduced. Verify that it's actually used. If it is unused, consider removing to keep imports minimal.


550-550: Extract shape from k
Explicitly unpacking B, T, H, K, V = *k.shape, v.shape[-1] might cause confusion if k or v shapes differ from expected. Consider validating shapes or adding an assertion for clarity.
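One way to act on this suggestion is a small validation wrapper. This is a hypothetical helper (not part of the library) that guards the unpacking pattern under the standardized [B, T, H, K] / [B, T, H, V] layouts:

```python
import numpy as np


def unpack_kv_shapes(k, v):
    # Assert the assumed layouts before unpacking so a transposed or
    # lower-rank tensor fails loudly instead of silently mis-unpacking.
    assert k.ndim == 4 and v.ndim == 4, "expected 4-D [B, T, H, K] and [B, T, H, V]"
    assert k.shape[:3] == v.shape[:3], "k and v must agree on [B, T, H]"
    B, T, H, K, V = *k.shape, v.shape[-1]
    return B, T, H, K, V
```

With k of shape [2, 32, 4, 64] and v of shape [2, 32, 4, 128], this returns (2, 32, 4, 64, 128); mismatched batch or head dimensions raise an AssertionError immediately.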


1006-1006: Defaulting head_first to False
Changing the default from True to False is a breaking API change for existing code. Communicate this clearly in release notes to avoid user confusion.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 112103e and 25ec0e0.

📒 Files selected for processing (4)
  • fla/ops/gla/chunk.py (36 hunks)
  • fla/ops/gsa/chunk.py (28 hunks)
  • fla/ops/gsa/fused_recurrent.py (10 hunks)
  • tests/ops/test_gsa.py (2 hunks)
🔇 Additional comments (19)
tests/ops/test_gsa.py (2)

147-147: Validate minimum T
The new range torch.randperm(T - 16) requires that T be at least 16. Please confirm or enforce that T >= 16 to avoid runtime errors.


329-329: Validate minimum T
Similarly, ensure T >= 16 for the random permutation logic to work without errors.

fla/ops/gsa/chunk.py (10)

4-4: Importing warnings
You've introduced an import for the warnings module, presumably for deprecation notices. Looks good.


13-14: Added chunk utilities
You're now importing prepare_chunk_indices and the chunk_gla_* modules. Ensure these references are used consistently and the imports are necessary.


92-93: Revised pointer initialization
These lines update how the p_o and p_A pointers are created for block operations. The offsets (bos * HQ + i_hq) * V and (bos * HQ + i_hq) * BT appear correct, but please confirm that the addressing remains consistent with the expected tensor shapes.


554-554: Hard-coded assumption for HQ
Setting HQ = q.shape[2] assumes the shape is [B, T, HQ, ...]. If the user passes [B, H, T, ...], this could lead to out-of-bound indexing or shape mismatches.


557-557: Conditional index preparation
indices = prepare_chunk_indices(offsets, BT) is a neat approach for varlen sequences. Just ensure that downstream usage covers all edge cases when offsets is None.


1084-1087: Deprecation warning and rearrange
You've added a deprecation warning for head_first and force a shape rearrange from [b h t ...] to [b t h ...]. This is a helpful transitional measure.
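The transitional pattern can be sketched end to end. This is a hypothetical shim (using numpy transposes as a stand-in for einops rearrange, and a lambda in place of the real chunked kernel) showing the shape round-trip the review describes: warn on the deprecated flag, convert [B, H, T, D] inputs to the new [B, T, H, D] layout, run, and convert the output back.

```python
import warnings

import numpy as np


def run_with_head_first_shim(kernel, q, head_first=False):
    if head_first:
        warnings.warn(
            "head_first=True is deprecated; pass [B, T, H, D] tensors instead",
            DeprecationWarning,
        )
        # equivalent to einops.rearrange(q, 'b h t d -> b t h d')
        q = np.transpose(q, (0, 2, 1, 3))
    o = kernel(q)  # stand-in for the real chunked attention kernel
    if head_first:
        # restore the caller's legacy [B, H, T, D] layout
        o = np.transpose(o, (0, 2, 1, 3))
    return o
```

Callers that still pass head-first tensors get a DeprecationWarning but see unchanged output shapes, which is what keeps existing code working during the migration.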


1088-1094: Warning on shape mismatch
The warning clarifies potential user error if [B, T, H, ...] is accidentally passed in head-first format. Good practice to reduce confusion.


1292-1296: Deprecated head_first in chunk_gsa
This block again warns about deprecation and rearranges shapes if head_first is used. Keeping a consistent approach to deprecating head_first is beneficial.


1297-1304: Format mismatch warning
Similar logic to lines 1088-1094. Good job preserving consistent user notifications across files.


1322-1322: Returning final_state
This updated return statement ensures both output and final_state are produced. Make sure any downstream usage is updated accordingly.

fla/ops/gsa/fused_recurrent.py (2)

4-4: Add warnings import
Importing warnings to handle deprecation or user notifications is consistent with the approach in other files.


10-10: Added rearrange import
Imported rearrange from einops. Make sure any newly introduced shape transformations are correct.

fla/ops/gla/chunk.py (5)

4-4: Add warnings import
This import is consistent with the other GLA modules for deprecation notices.


10-10: Use of rearrange
Introducing rearrange from einops allows flexible tensor shape transformations.


1292-1296: Deprecation block for head_first
This block mirrors the pattern used elsewhere, providing a warning and rearranging shapes if head_first is used.


1297-1304: Shape mismatch checks
These lines add a warning if q.shape[1] < q.shape[2] when head_first=False. Good for preventing user confusion.
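The heuristic reads roughly as follows; this is a hypothetical recreation of the check described above (names and message text are assumptions, only the T-versus-H comparison comes from the review):

```python
import warnings


def check_seq_layout(q_shape, head_first=False):
    # With head_first=False inputs should be [B, T, H, ...]. If the dim at
    # position 1 (T) is smaller than the dim at position 2 (H), the tensor
    # was most likely passed in the old head-first [B, H, T, ...] layout,
    # since sequence length normally far exceeds the number of heads.
    if not head_first and q_shape[1] < q_shape[2]:
        warnings.warn(
            "input has T < H; did you pass a head-first [B, H, T, ...] tensor?",
            stacklevel=2,
        )
```

A shape like (2, 4, 2048, 64) triggers the warning, while a typical (2, 2048, 4, 64) input passes silently; the check is a heuristic, so genuinely short sequences with many heads would warn spuriously.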


1322-1322: Return assignment
Returning o, final_state from the ChunkGLAFunction call. Confirm that all call sites expect this new structure.
