[Deprecated] head_first option removed for gla variants #334
Walkthrough

This PR makes extensive updates across multiple modules to clarify tensor shapes and remove deprecated parameters.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant U as User Code
    participant F as Chunk/Kernel Function
    participant K as Kernel/Lower-Level Operation
    U->>F: Call function with tensors (optional head_first flag)
    alt head_first is provided and True
        F->>F: Issue deprecation warning
        F->>F: Rearrange tensor from head-first to time-first format
    else head_first is False
        F->>F: Proceed with input tensors as provided
    end
    F->>K: Invoke kernel with standardized tensor pointers and parameters
    K-->>F: Return computed tensor result
    alt head_first was True
        F->>F: Rearrange output back to head-first format if needed
    end
    F->>U: Return final output tensor
```
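The flow in the diagram can be sketched as a thin compatibility wrapper. The sketch below is an illustration only, not the library's code: it uses NumPy transposes in place of the einops-based rearranges, and `_kernel` is a hypothetical stand-in for the real time-first kernel.

```python
import warnings
import numpy as np

def _kernel(q):
    # Hypothetical time-first kernel: expects [B, T, H, D] input.
    return q * 2.0

def fused_op(q, head_first=False):
    """Compatibility wrapper: accepts head-first [B, H, T, D] input
    during the deprecation period, but computes in time-first layout."""
    if head_first:
        warnings.warn(
            "head_first is deprecated and will be removed in a future version.",
            stacklevel=2,
        )
        # [B, H, T, D] -> [B, T, H, D], i.e. rearrange('b h t d -> b t h d')
        q = q.transpose(0, 2, 1, 3)
    o = _kernel(q)
    if head_first:
        # Convert the output back so existing callers keep working.
        o = o.transpose(0, 2, 1, 3)
    return o

x = np.arange(2 * 4 * 16 * 8, dtype=np.float32).reshape(2, 4, 16, 8)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the deprecation warning here
    y = fused_op(x, head_first=True)
print(y.shape)  # output comes back in the head-first shape (2, 4, 16, 8)
```

The round-trip transpose is why the PR can standardize the kernels on one layout while keeping `head_first=True` callers working during the deprecation window.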
Actionable comments posted: 1
🧹 Nitpick comments (12)
fla/ops/retention/fused_recurrent.py (2)
25-28: Added deprecation warning for the head_first parameter.

Good practice to warn users about the upcoming removal of the deprecated parameter. However, the warning should include a `stacklevel` parameter to ensure it points to the caller's code instead of the warning implementation.

```diff
-    warnings.warn(
-        "head_first is deprecated and will be removed in a future version. "
-        "Please use head_first=False for now instead."
-    )
+    warnings.warn(
+        "head_first is deprecated and will be removed in a future version. "
+        "Please use head_first=False for now instead.",
+        stacklevel=2
+    )
```

🧰 Tools
🪛 Ruff (0.8.2)

25-25: No explicit `stacklevel` keyword argument found (B028)
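For context on the recurring `stacklevel` suggestion: with the default `stacklevel=1`, the warning is reported at the `warnings.warn(...)` line inside the library, while `stacklevel=2` attributes it to the caller. A minimal, self-contained illustration (the function names here are made up):

```python
import warnings

def deprecated_api():
    # stacklevel=2 makes the warning point at whoever called this
    # function, not at this warnings.warn line itself.
    warnings.warn("deprecated_api is deprecated", DeprecationWarning, stacklevel=2)
    return 42

def user_code():
    return deprecated_api()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    user_code()

# The recorded warning's filename/lineno now identify the call site
# inside user_code(), which is what Ruff's B028 check encourages.
print(caught[0].category.__name__, caught[0].lineno)
```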
30-36: Added defensive warning for potential tensor format mismatches.

Good defensive programming to detect when users might have passed tensors in the wrong format. As with the previous warning, this should include a `stacklevel` parameter.

```diff
-    warnings.warn(
-        f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-        "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-        "when head_first=False was specified. "
-        "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-    )
+    warnings.warn(
+        f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+        "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+        "when head_first=False was specified. "
+        "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+        stacklevel=2
+    )
```

🧰 Tools
🪛 Ruff (0.8.2)

31-31: No explicit `stacklevel` keyword argument found (B028)
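The shape heuristic being reviewed can be reproduced standalone: if the second dimension (expected to be `T`) is smaller than the third (expected to be `H`), the tensor was plausibly passed head-first. A hedged sketch of the check, operating on shapes only:

```python
def looks_head_first(shape):
    """Heuristic from the review: in time-first layout [B, T, H, D],
    seq_len is almost always >= num_heads, so shape[1] < shape[2]
    suggests the caller actually passed head-first [B, H, T, D] data.
    It is only a heuristic: genuinely short sequences with many heads
    can trigger a false positive, hence a warning rather than an error."""
    return shape[1] < shape[2]

print(looks_head_first((2, 1024, 8, 64)))  # time-first input: False
print(looks_head_first((2, 8, 1024, 64)))  # head-first passed by mistake: True
```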
fla/ops/retention/chunk.py (2)
55-60: Good use of deprecation warning for the head_first parameter.

The warning correctly informs users that `head_first` is being deprecated and provides guidance on what to use instead.

Consider adding the `stacklevel` parameter to the warning to help users identify the source of the warning in their code:

```diff
     warnings.warn(
         "head_first is deprecated and will be removed in a future version. "
         "Please use head_first=False for now instead.",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

56-56: No explicit `stacklevel` keyword argument found (B028)
61-67: Helpful format mismatch detection.

This check provides a useful warning when tensor shapes suggest potential incorrect usage, helping users debug issues related to tensor format.

Consider adding the `stacklevel` parameter to the warning to help users identify the source of the warning in their code:

```diff
     warnings.warn(
         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
         "when head_first=False was specified. "
         "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

62-62: No explicit `stacklevel` keyword argument found (B028)
fla/ops/simple_gla/fused_recurrent.py (2)
90-95: Added deprecation warning for head_first parameter.

The warning correctly informs users that `head_first` is being deprecated and provides guidance on using `head_first=False`.

Consider adding the `stacklevel` parameter to the warning to help users identify the source of the warning in their code:

```diff
     warnings.warn(
         "head_first is deprecated and will be removed in a future version. "
         "Please use head_first=False for now instead.",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

91-91: No explicit `stacklevel` keyword argument found (B028)
96-102: Added input format validation.

This check helps users identify potential issues when their tensor shapes don't match the expected format, improving developer experience.

Consider adding the `stacklevel` parameter to the warning to help users identify the source of the warning in their code:

```diff
     warnings.warn(
         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
         "when head_first=False was specified. "
         "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

97-97: No explicit `stacklevel` keyword argument found (B028)
fla/ops/gla/fused_recurrent.py (2)
92-101: Comprehensive handling of head_first deprecation.

The code properly handles all tensors that need to be rearranged, including the conditional rearrangement of `gk` and `gv` if they are provided. The warning message is clear and informative.

Consider adding the `stacklevel` parameter to the warning to help users identify the source of the warning in their code:

```diff
     warnings.warn(
         "head_first is deprecated and will be removed in a future version. "
         "Please use head_first=False for now instead.",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

93-93: No explicit `stacklevel` keyword argument found (B028)
102-108: Added input format validation.

The warning about potential format mismatch helps users identify and debug issues when tensor shapes don't match the expected format.

Consider adding the `stacklevel` parameter to the warning to help users identify the source of the warning in their code:

```diff
     warnings.warn(
         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
         "when head_first=False was specified. "
         "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

103-103: No explicit `stacklevel` keyword argument found (B028)
fla/ops/simple_gla/chunk.py (2)
252-256: Added warning for deprecated head_first parameter.

A proper warning message has been added to inform users that the `head_first` parameter is deprecated and will be removed in a future version. This is good practice for maintaining backward compatibility while guiding users toward the new API.

Consider adding a specific stacklevel to the warning to properly attribute it to the caller:

```diff
     warnings.warn(
         "head_first is deprecated and will be removed in a future version. "
-        "Please use head_first=False for now instead."
+        "Please use head_first=False for now instead.",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

253-253: No explicit `stacklevel` keyword argument found (B028)
258-264: Added warning for potential format mismatch.

A warning has been added to alert users when the input tensor shape suggests a potential format mismatch. This helps users identify and fix issues with their input data.

Consider adding a specific stacklevel to the warning to properly attribute it to the caller:

```diff
     warnings.warn(
         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
         "when head_first=False was specified. "
-        "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+        "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

259-259: No explicit `stacklevel` keyword argument found (B028)
fla/ops/gla/chunk.py (2)
1421-1433: Well-implemented deprecation warning.

Good implementation of the deprecation warning with clear guidance for users. The warning explains that `head_first` is deprecated and provides instruction on how to transition to the new format.

However, the warning is missing the `stacklevel` parameter, which helps attribute the warning to the correct call site in user code.

Consider adding the `stacklevel` parameter:

```diff
     warnings.warn(
         "head_first is deprecated and will be removed in a future version. "
-        "Please use head_first=False for now instead."
+        "Please use head_first=False for now instead.",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

1422-1422: No explicit `stacklevel` keyword argument found (B028)

1428-1428: No explicit `stacklevel` keyword argument found (B028)
1428-1433: Helpful shape validation warning.

The shape validation warning provides useful diagnostics to help users identify when they might be using tensors in the wrong format. This is particularly valuable during the transition period.

Similar to the previous warning, add the `stacklevel` parameter:

```diff
     warnings.warn(
         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
         "when head_first=False was specified. "
-        "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+        "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+        stacklevel=2
     )
```

🧰 Tools
🪛 Ruff (0.8.2)

1428-1428: No explicit `stacklevel` keyword argument found (B028)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (21)
- fla/ops/abc/chunk.py (1 hunks)
- fla/ops/common/chunk_h.py (4 hunks)
- fla/ops/common/chunk_h_parallel.py (7 hunks)
- fla/ops/common/chunk_o.py (9 hunks)
- fla/ops/common/fused_recurrent.py (11 hunks)
- fla/ops/gated_delta_rule/chunk.py (3 hunks)
- fla/ops/gla/chunk.py (26 hunks)
- fla/ops/gla/fused_recurrent.py (3 hunks)
- fla/ops/gla/naive.py (2 hunks)
- fla/ops/gsa/chunk.py (4 hunks)
- fla/ops/gsa/fused_recurrent.py (1 hunks)
- fla/ops/nsa/naive.py (3 hunks)
- fla/ops/retention/chunk.py (2 hunks)
- fla/ops/retention/fused_recurrent.py (3 hunks)
- fla/ops/simple_gla/chunk.py (7 hunks)
- fla/ops/simple_gla/fused_recurrent.py (3 hunks)
- fla/ops/utils/cumsum.py (4 hunks)
- tests/ops/test_cumsum.py (4 hunks)
- tests/ops/test_gla.py (5 hunks)
- tests/ops/test_retention.py (3 hunks)
- tests/ops/test_simple_gla.py (7 hunks)
🧰 Additional context used
🧬 Code Definitions (8)
tests/ops/test_gla.py (2)
- fla/ops/gla/chunk.py (1): `chunk_gla` (1348-1463)
- fla/ops/gla/fused_recurrent.py (1): `fused_recurrent_gla` (13-141)

fla/ops/retention/fused_recurrent.py (1)
- fla/ops/simple_gla/fused_recurrent.py (1): `fused_recurrent_simple_gla` (13-133)

fla/ops/retention/chunk.py (1)
- fla/ops/simple_gla/chunk.py (1): `chunk_simple_gla` (179-294)

fla/ops/common/chunk_h.py (1)
- fla/ops/common/utils.py (1): `prepare_chunk_offsets` (65-69)

fla/ops/common/fused_recurrent.py (1)
- fla/ops/utils/cumsum.py (1): `chunk_global_cumsum` (362-381)

fla/ops/common/chunk_h_parallel.py (1)
- fla/ops/common/utils.py (2): `prepare_chunk_indices` (56-61), `prepare_chunk_offsets` (65-69)

fla/ops/utils/cumsum.py (1)
- fla/ops/common/utils.py (1): `prepare_chunk_indices` (56-61)

fla/ops/common/chunk_o.py (1)
- fla/ops/common/utils.py (1): `prepare_chunk_indices` (56-61)
🪛 Ruff (0.8.2)
fla/ops/retention/fused_recurrent.py
25-25: No explicit stacklevel keyword argument found
(B028)
31-31: No explicit stacklevel keyword argument found
(B028)
fla/ops/retention/chunk.py
56-56: No explicit stacklevel keyword argument found
(B028)
62-62: No explicit stacklevel keyword argument found
(B028)
fla/ops/gla/fused_recurrent.py
93-93: No explicit stacklevel keyword argument found
(B028)
103-103: No explicit stacklevel keyword argument found
(B028)
fla/ops/simple_gla/fused_recurrent.py
91-91: No explicit stacklevel keyword argument found
(B028)
97-97: No explicit stacklevel keyword argument found
(B028)
fla/ops/simple_gla/chunk.py
253-253: No explicit stacklevel keyword argument found
(B028)
259-259: No explicit stacklevel keyword argument found
(B028)
fla/ops/gla/chunk.py
1422-1422: No explicit stacklevel keyword argument found
(B028)
1428-1428: No explicit stacklevel keyword argument found
(B028)
⏰ Context from checks skipped due to timeout of 90000ms (3)
- GitHub Check: test
- GitHub Check: test
- GitHub Check: test
🔇 Additional comments (107)
fla/ops/gla/naive.py (2)
21-21: Tensor dimensions standardized with transpose operation.

This change ensures that the tensors are consistently in the format [B, H, T, ...] by transposing dimensions 1 (sequence length) and 2 (heads) before conversion to float. This aligns with the PR objective of removing the `head_first` option.

41-41: Output tensor format standardized.

The output tensor is now transposed back to the expected format before converting to the specified data type, ensuring consistency with the input tensor format adjustment on line 21.
fla/ops/abc/chunk.py (1)
1096-1096: Documentation updated with correct tensor shape description.

The updated documentation for the `s` parameter correctly clarifies the expected tensor shape based on the `head_first` flag. This ensures users understand the proper input format.

fla/ops/gsa/fused_recurrent.py (1)

480-480: Documentation updated with correct tensor shape description.

The updated documentation for the `s` parameter correctly clarifies the expected tensor shape based on the `head_first` flag. This ensures users understand the proper input format and maintains consistency with similar documentation updates in other files.

fla/ops/retention/fused_recurrent.py (4)
4-4: Added necessary imports for deprecation handling.

The added imports for `warnings` and `rearrange` from `einops` support the new functionality for deprecation warnings and tensor format conversion.

Also applies to: 8-8

29-29: Added tensor format conversion for backward compatibility.

This change ensures that tensors are converted from head-first format to sequence-first format when `head_first=True`, allowing internal processing to be standardized.

37-38: Simplified tensor shape calculation and expansion.

The calculation of `s` now directly uses `q.shape[2]` for the range, and `g` is expanded consistently regardless of the `head_first` parameter value. This simplifies the code and removes conditional logic.

50-52: Added output format conversion for backward compatibility.

This change ensures that if `head_first=True`, the output tensor is converted back to head-first format before returning, maintaining backward compatibility with existing code.

fla/ops/retention/chunk.py (4)
4-4: Import requirements properly added for the new functionality.

The added imports for `warnings` and `rearrange` from `einops` are necessary to support the new deprecation warnings and tensor reshaping features.

Also applies to: 8-8

68-69: Simplified tensor shape handling.

The computation of `s` and expansion of `g` have been simplified by directly using the shape of the `q` tensor, making the code more maintainable.

70-79: Properly removed head_first parameter from function call.

Successfully removed the `head_first` parameter from the `chunk_simple_gla` call, which is consistent with the PR objective of deprecating this parameter.

80-82: Format conversion for backward compatibility.

This code correctly handles backward compatibility by rearranging the output tensor back to head-first format when the parameter is set to `True`.

tests/ops/test_gla.py (5)
52-55: Updated tensor shapes to use time-first format.

The tensor dimensions have been changed from `(B, H, T, D)` (head-first) to `(B, T, H, D)` (time-first), aligning with the PR objective to standardize on the time-first format.

139-147: Improved readability with multi-line formatting.

The function call to `chunk_gla` has been reformatted to use multi-line arguments, which improves code readability while maintaining the same functionality.

155-172: Consistently formatted function calls.

The function calls to `fused_recurrent_gla` are now consistently formatted with one argument per line, improving readability and maintainability.

208-211: Updated variable name and generation logic.

The variable `offsets` has been renamed to `cu_seqlens` to match the parameter name in the called functions, and the logic for generating its values has been updated to ensure a minimum sequence length of 16.

225-226: Consistent parameter naming.

The `offsets` parameter has been renamed to `cu_seqlens` in all function calls, maintaining consistency with the parameter names in the function definitions.

Also applies to: 231-232, 248-249
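For readers unfamiliar with the `cu_seqlens` convention the tests are moving to: it is a cumulative-sequence-lengths vector used for variable-length (packed) batches, where sequence `i` occupies rows `cu_seqlens[i]:cu_seqlens[i+1]` of the packed tensor. A small plain-Python illustration (not the library's actual helpers):

```python
def to_cu_seqlens(lengths):
    """Build the cumulative boundary vector [0, l0, l0+l1, ...] from
    per-sequence lengths, as used for packed variable-length batches."""
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)
    return cu

def slice_bounds(cu_seqlens, i):
    """Row range of sequence i inside the packed [total_T, H, D] tensor."""
    return cu_seqlens[i], cu_seqlens[i + 1]

cu = to_cu_seqlens([16, 32, 24])
print(cu)                   # [0, 16, 48, 72]
print(slice_bounds(cu, 1))  # (16, 48)
```

Naming the test variable `cu_seqlens` rather than `offsets` makes it match this convention and the parameter name in the kernels being tested.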
fla/ops/simple_gla/fused_recurrent.py (3)
4-4: Added required imports for new functionality.

The imports for `warnings` and `rearrange` from `einops` have been added to support the deprecation warnings and tensor reshaping functionality.

Also applies to: 8-8

120-130: Removed head_first parameter from function call.

The `head_first` parameter has been properly removed from the `fused_recurrent` function call, aligning with the PR objective to deprecate this parameter.

131-133: Format conversion for backward compatibility.

The code correctly handles backward compatibility by rearranging the output tensor back to head-first format when the parameter is set to `True`.

fla/ops/gla/fused_recurrent.py (2)
4-4: Added required imports for new functionality.

The imports for `warnings` and `rearrange` from `einops` have been added to support the deprecation warnings and tensor reshaping functionality.

Also applies to: 8-8

139-141: Format conversion for backward compatibility.

The code correctly handles backward compatibility by rearranging the output tensor back to head-first format when the parameter is set to `True`.

fla/ops/gated_delta_rule/chunk.py (3)
351-351: Documentation improved.

The assertion message now provides clearer guidance about the expected shape of the beta tensor based on the `head_first` parameter.

369-369: Enhanced tensor rearrangement flexibility.

The updated rearrangement pattern using ellipsis (`...`) allows handling tensors with varying trailing dimensions more elegantly, making the code more robust to different tensor shapes.

386-386: Consistent rearrangement pattern.

This modification matches the input tensor rearrangement approach (using ellipsis), maintaining consistency in how tensors are handled throughout the function.
tests/ops/test_simple_gla.py (5)
40-40: Parameter renamed for better clarity.

The parameter name change from `BT` to `chunk_size` improves code readability by using a more descriptive name that better represents the parameter's purpose.

53-53: Maintained compatibility with variable name.

Good approach to assign the new parameter to the existing variable name, preserving the internal implementation while improving the API.

63-63: Improved variable naming.

The rename from `l` to `t` better conveys that this dimension corresponds to the sequence length (time) dimension, improving code readability.

66-66: Updated parameter reference.

Updated to use the new parameter name `chunk_size` for consistency with the function signature.

73-73: Consistent variable and parameter naming.

The loop condition was updated to use the renamed variable `t` and parameter `chunk_size`, maintaining logical consistency.

fla/ops/nsa/naive.py (3)
30-30: Fixed tensor shape documentation.

Corrected the documentation for the `indices` parameter shape based on the `head_first` parameter, making it more accurate and consistent with the actual implementation.

55-55: More flexible tensor rearrangement.

Updated to use ellipsis (`...`) in the rearrangement pattern, allowing the function to handle tensors with varying trailing dimensions more flexibly.

95-95: Consistent tensor handling.

The output tensor rearrangement now matches the input handling style, creating consistency throughout the function.
tests/ops/test_retention.py (2)
33-33: Changed test dtype from bfloat16 to float16.

Similar to the changes in test_simple_gla.py, the test dtype has been standardized to `torch.float16` instead of `torch.bfloat16`.
This appears to be a consistent change across test files. Is there a particular reason for switching from bfloat16 to float16? This might affect numerical precision in some edge cases.
Also applies to: 87-87
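On the bfloat16/float16 question raised above: the trade-off is mantissa precision (float16 carries 10 mantissa bits, bfloat16 only 7) versus exponent range (bfloat16 matches float32's 8 exponent bits, while float16 has 5 and overflows near 65504). A small stdlib demonstration of the precision side, using bit-level truncation to emulate bfloat16:

```python
import struct

def as_float16(x):
    # Round-trip through IEEE half precision (10 mantissa bits).
    return struct.unpack('>e', struct.pack('>e', x))[0]

def as_bfloat16(x):
    # Emulate bfloat16 by truncating a float32 to its upper 16 bits
    # (7 mantissa bits); real hardware typically rounds, this truncates.
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    return struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))[0]

x = 1.001
print(as_float16(x))   # 1.0009765625 -- float16 resolves the difference
print(as_bfloat16(x))  # 1.0          -- bfloat16 cannot, near 1.0
```

So near 1.0, float16 is roughly 8x finer-grained than bfloat16, which is why tolerance-based test comparisons often pass more comfortably under float16.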
107-107: Modified random index generation range.

The test now uses a broader exclusion range for random permutation (T-16 instead of T-1), which changes how random segments are distributed during testing.
This change might affect the distribution of test cases. Can you ensure this won't reduce test coverage or create edge cases that aren't being tested properly?
tests/ops/test_cumsum.py (4)
55-55: Updated the test dtype from bfloat16 to float16.

The test parameterization has been changed to use `torch.float16` instead of `torch.bfloat16`. This change aligns with similar updates in other test files to standardize the data types used in testing.

75-75: Updated the test dtype from bfloat16 to float16.

Consistent with the previous change, this test now uses `torch.float16` instead of `torch.bfloat16`, standardizing dtype usage across tests.

97-97: Updated the test dtype from bfloat16 to float16.

The test parameterization has been changed to use `torch.float16` instead of `torch.bfloat16`, maintaining consistency across all test functions in this file.

121-121: Updated the test dtype from bfloat16 to float16.

The final test function's dtype parameterization has been changed to `torch.float16`, completing the standardization of dtype usage throughout the test file.

fla/ops/common/chunk_h.py (10)
309-309: Simplified tensor shape unpacking.

The code now directly unpacks the shape from the `k` and `v` tensors, removing the conditional logic that previously depended on the `head_first` parameter. This simplifies the code and makes tensor shape handling more consistent.

315-316: Streamlined variable assignment for equal-length sequences.

The assignment of the `N`, `NS`, and `split_offsets` variables has been simplified when processing equal-length sequences, improving code readability.

318-318: Simplified variable assignment for variable-length sequences.

The assignment of the `N` and `NS` variables has been streamlined when processing variable-length sequences, making the code more concise.

320-320: Modified tensor initialization shape.

The tensor `h` initialization has been updated to use a more consistent shape format, removing conditional logic that previously depended on the `head_first` parameter.

343-343: Hardcoded HEAD_FIRST parameter to False.

The `HEAD_FIRST` parameter is now explicitly set to `False` rather than being dynamically determined, reflecting the removal of the `head_first` option from the function signature. This simplifies the code and enforces a consistent tensor layout.

364-365: Simplified tensor shape unpacking and HQ assignment.

The code now directly unpacks dimensions from tensor shapes and assigns `HQ` from `q.shape[2]`, removing conditional logic that previously depended on the `head_first` parameter.

372-373: Streamlined variable assignment for equal-length sequences.

The assignment of the `N`, `NS`, and `split_offsets` variables has been simplified for the backward pass when processing equal-length sequences.

375-375: Simplified variable assignment for variable-length sequences.

The assignment of the `N` and `NS` variables has been streamlined for the backward pass when processing variable-length sequences.

378-378: Modified tensor initialization shape for backward pass.

The tensor `dh` initialization has been updated to use a more consistent shape format for the backward pass, removing conditional logic that previously depended on the `head_first` parameter.

405-405: Hardcoded HEAD_FIRST parameter to False in backward pass.

The `HEAD_FIRST` parameter is now explicitly set to `False` in the backward pass kernel, ensuring consistent tensor layout handling throughout forward and backward operations.

fla/ops/gsa/chunk.py (4)
9-9: Added einops import for tensor reshaping.

The import of the `rearrange` and `reduce` functions from `einops` facilitates tensor reshaping operations used later in the code, particularly for handling different tensor layouts.

1153-1153: Changed the default value of head_first parameter from True to False.

The default value for the `head_first` parameter has been changed from `True` to `False`, reflecting the deprecation of the `head_first` option as indicated in the PR title. This aligns with the broader effort to standardize tensor shapes across the codebase.

1165-1165: Updated documentation to reflect the correct tensor shape.

The documentation for the `s` parameter has been updated to correctly describe the expected tensor shape when `head_first=False`. This ensures that users understand the expected input format after the parameter deprecation.

1271-1272: Added conditional tensor reshaping to maintain backward compatibility.

A conditional block has been added to rearrange the output tensor `o` when `head_first` is `True`. This ensures backward compatibility while standardizing on the new tensor layout, allowing existing code that expects the head-first format to continue working.

fla/ops/common/chunk_h_parallel.py (11)
14-14: Updated imports to include prepare_chunk_indices function.

The import statement has been updated to include `prepare_chunk_indices` from the utils module, which is now used to compute indices for chunking instead of relying on the previously passed `indices` parameter.

493-493: Simplified tensor shape unpacking.

The code now directly unpacks the shape from the `k` and `v` tensors, removing the conditional logic that previously depended on the `head_first` parameter. This simplifies the code and makes tensor shape handling more consistent.

499-500: Added automatic chunk indices computation.

Instead of requiring indices as a parameter, the function now computes them internally using `prepare_chunk_indices` when dealing with variable-length sequences. This simplifies the API while maintaining the same functionality.

502-502: Modified tensor initialization shape.

The tensor `h` initialization has been updated to use a more consistent shape format, matching the changes in chunk_h.py and removing the dependency on the `head_first` parameter.

524-524: Hardcoded HEAD_FIRST parameter to False.

The `HEAD_FIRST` parameter is now explicitly set to `False` rather than being dynamically determined, which simplifies the code and enforces a consistent tensor layout.

545-545: Hardcoded HEAD_FIRST parameter to False in reduction kernel.

The `HEAD_FIRST` parameter is now explicitly set to `False` in the reduction kernel as well, ensuring consistent tensor layout handling throughout all operations.

566-567: Simplified tensor shape unpacking for backward pass.

The code now directly unpacks dimensions from tensor shapes for the backward pass, removing conditional logic that previously depended on the `head_first` parameter.

574-575: Added automatic chunk indices computation for backward pass.

Similar to the forward pass, the backward pass now computes indices internally using `prepare_chunk_indices` when dealing with variable-length sequences, simplifying the API.

578-578: Modified tensor initialization shape for backward pass.

The tensor `dh` initialization has been updated to use a more consistent shape format for the backward pass, matching the changes in the forward pass and in chunk_h.py.

604-604: Hardcoded HEAD_FIRST parameter to False in backward parallel kernel.

The `HEAD_FIRST` parameter is now explicitly set to `False` in the backward parallel kernel, maintaining consistency with the forward pass changes.

628-628: Hardcoded HEAD_FIRST parameter to False in backward reduction kernel.

The `HEAD_FIRST` parameter is now explicitly set to `False` in the backward reduction kernel as well, completing the consistent application of this change throughout all kernel functions.

fla/ops/common/fused_recurrent.py (13)
53-54: Removed HEAD_FIRST parameter from kernel function signature.

The `HEAD_FIRST` parameter has been removed from the function signature, simplifying the interface and standardizing on a specific tensor layout. This makes the code cleaner and easier to maintain.

66-75: Improved memory access pattern by simplifying pointer calculations.

The conditional branching based on `HEAD_FIRST` has been removed, and the pointer calculations are now consistently based on the `bos` variable and the `REVERSE` flag. This change simplifies the logic and likely improves performance by reducing branches in the kernel.

103-112: Simplified pointer updates in the loop.

The pointer update calculations have been streamlined to only depend on the `REVERSE` flag, which makes the code more maintainable and easier to reason about.

175-184: Consistent pointer initialization in backward kernel.

Similar to the forward kernel, the backward kernel now has simplified pointer calculations that only depend on the `bos` variable and the `REVERSE` flag, making the code more consistent.

213-222: Simplified pointer updates in backward pass.

The pointer update logic in the backward pass is now cleaner and consistent with the changes in the forward pass.

227-238: Simplified pointer reinitialization for second phase of backward pass.

The pointer calculations for the second phase of the backward pass have been standardized to follow the same pattern as the rest of the code.

265-276: Consistent pointer update logic in second backward pass loop.

The pointer update logic in the second backward pass loop is now simpler and consistent with the rest of the code.

294-297: Removed head_first parameter from function signature.

The `head_first` parameter has been removed from the `fused_recurrent_fwd` function signature, aligning with the broader refactoring to standardize on a specific tensor layout.

302-303: Simplified final state tensor shape declaration.

The shape of the final state tensor `ht` is now defined directly using the `N, H, K, V` dimensions, making the code more straightforward.

347-349: Removed head_first parameter from backward function signature.

The `head_first` parameter has been removed from the `fused_recurrent_bwd` function signature, consistent with the changes to the forward function.

395-399: Updated chunk_global_cumsum calls to match new API.

The calls to `chunk_global_cumsum` have been updated to remove the `head_first` parameter, aligning with the changes in that function's signature.

421-422: Removed head_first parameter from FusedRecurrentFunction.forward.

The `head_first` parameter has been removed from the autograd function's forward method, consistent with the other function signature changes.

471-472: Updated return statement in backward method.

The return statement in the backward method has been updated to match the changes in the function signatures.

fla/ops/simple_gla/chunk.py (9)
fla/ops/simple_gla/chunk.py (9)
4-4: Added warnings import for deprecation notices.
The `warnings` module has been imported to support the deprecation warnings added for the `head_first` parameter.

9-9: Added einops import for tensor rearrangement.
The `rearrange` function from `einops` is now imported to handle tensor rearrangement when `head_first=True` is used.

28-28: Simplified chunk_local_cumsum call by removing head_first and indices parameters.
The call to `chunk_local_cumsum` now omits the `head_first` and `indices` parameters, relying on the internal computation of indices.

133-136: Simplified shape extraction in forward method.
The tensor shape extraction in the `forward` method is now simpler, directly using `q.shape[1]` to determine `T`.

159-160: Updated context saving in backward method.
The context now only saves `chunk_size`, `scale`, and `offsets`, removing the no longer needed `head_first` and `indices` parameters.

173-174: Simplified chunk_local_cumsum call in backward method.
The call to `chunk_local_cumsum` in the backward method has been updated to align with the new function signature.

175-175: Simplified return statement in backward method.
The return statement in the backward method has been updated to match the changes in function signatures.

257-257: Added tensor rearrangement for backward compatibility.
When `head_first=True` is used, the tensors are automatically rearranged to the standard format expected by the implementation. This ensures backward compatibility during the deprecation period.
292-294: Added tensor rearrangement for output when head_first=True.
When `head_first=True` is used, the output tensor is rearranged back to the head-first format for backward compatibility.

fla/ops/utils/cumsum.py (5)
10-10: Added import for prepare_chunk_indices function.
The `prepare_chunk_indices` function is now imported from `fla.ops.common.utils`, allowing the indices to be computed internally rather than passed as parameters.

243-244: Compute indices internally using prepare_chunk_indices.
Instead of receiving `indices` as a parameter, the function now computes it internally using `prepare_chunk_indices` when `offsets` is provided. This is a good design choice for encapsulation.
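`prepare_chunk_indices` is internal to the library, but the idea is straightforward. Assuming `offsets` holds cumulative sequence boundaries, a hypothetical equivalent might look like:

```python
import torch

def chunk_indices_sketch(offsets: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Return (sequence, chunk) index pairs for variable-length sequences.

    offsets: cumulative boundaries, e.g. [0, 100, 356] for lengths 100 and 256.
    """
    lengths = offsets[1:] - offsets[:-1]
    num_chunks = (lengths + chunk_size - 1) // chunk_size  # ceil-div per sequence
    pairs = [
        (seq, chunk)
        for seq, n in enumerate(num_chunks.tolist())
        for chunk in range(n)
    ]
    return torch.tensor(pairs, dtype=torch.long)

idx = chunk_indices_sketch(torch.tensor([0, 100, 356]), chunk_size=64)
# lengths 100 and 256 with chunk_size 64 give 2 + 4 = 6 (sequence, chunk) pairs
```

Computing this inside the op, rather than threading an `indices` argument through every caller, is what the encapsulation comment above refers to.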
274-275: Added internal computation of indices in chunk_local_cumsum_vector.
Similar to `chunk_local_cumsum_scalar`, the `chunk_local_cumsum_vector` function now computes indices internally using `prepare_chunk_indices`.

391-393: Updated function signature with output_dtype and kwargs.
The `chunk_local_cumsum` function signature has been updated to include `output_dtype` as a named parameter and add `**kwargs` for backward compatibility.

397-400: Simplified function calls by removing indices parameter.
The calls to `chunk_local_cumsum_scalar` and `chunk_local_cumsum_vector` have been updated to pass `output_dtype` but omit the now-internally-computed `indices`.

fla/ops/common/chunk_o.py (9)
10-10: Added import for prepare_chunk_indices function.
The `prepare_chunk_indices` function is now imported from `fla.ops.common.utils`, similar to the changes in other files.

467-470: Updated shape extraction and indices computation in chunk_fwd_o.
The `chunk_fwd_o` function now uses a simpler shape extraction approach and computes `indices` internally using `prepare_chunk_indices`.

492-493: Set HEAD_FIRST parameter to False in kernel call.
The `HEAD_FIRST` parameter is now consistently set to `False` in the kernel call, reflecting the standardization on a specific tensor layout.

507-510: Updated shape extraction and indices computation in chunk_bwd_dv.
The `chunk_bwd_dv` function now uses a simpler shape extraction approach and computes `indices` internally using `prepare_chunk_indices`.

541-542: Set HEAD_FIRST parameter to False in backward kernel call.
The `HEAD_FIRST` parameter is now consistently set to `False` in the backward kernel call.

556-559: Updated shape extraction and indices computation in chunk_bwd_dv_local.
The `chunk_bwd_dv_local` function now uses a simpler shape extraction approach and computes `indices` internally using `prepare_chunk_indices`.

588-589: Set HEAD_FIRST parameter to False in dv_local kernel call.
The `HEAD_FIRST` parameter is now consistently set to `False` in the `dv_local` kernel call.

608-611: Updated shape extraction and indices computation in chunk_bwd_dqkwg.
The `chunk_bwd_dqkwg` function now uses a simpler shape extraction approach and computes `indices` internally using `prepare_chunk_indices`.

648-649: Set HEAD_FIRST parameter to False in dqkwg kernel call.
The `HEAD_FIRST` parameter is now consistently set to `False` in the `dqkwg` kernel call.

fla/ops/gla/chunk.py (7)
4-4: Good addition of the warnings module
Adding the `warnings` module is appropriate for implementing deprecation warnings.

856-857: Improved tensor shape handling
The explicit extraction of tensor dimensions makes the code more readable and clearer about the expected input format. This is a good change as it standardizes all tensors to use the [B, T, H, K/V] format.
Also applies to: 955-956, 989-990, 1023-1024, 1057-1058, 1206-1207, 1298-1299

858-859: Centralized index preparation
Using the `prepare_chunk_indices` helper function consistently across the codebase simplifies the logic and reduces code duplication.
Also applies to: 957-958, 991-993, 1025-1027, 1059-1060, 1106-1107

863-864: Simplified tensor allocation
The tensor allocation now uses a consistent format without conditional branches for different tensor layouts, which makes the code easier to understand and maintain.
Also applies to: 906-907, 995-996

879-880: Hardcoded HEAD_FIRST parameter
Setting `HEAD_FIRST=False` consistently throughout the codebase aligns with the PR objective to deprecate and remove the `head_first` option.
Also applies to: 900-901, 925-926, 940-941, 977-978, 1009-1010, 1044-1045, 1085-1086, 1135-1136

1153-1154: Consistent cumsum calculation
The `chunk_local_cumsum` function is now used consistently with the same parameter structure, improving code readability.
Also applies to: 1209-1210

1451-1463: Improved ChunkGLAFunction.apply call and backward compatibility
The refactored function call structure is cleaner and more maintainable. The conditional rearrangement of the output tensor when `head_first=True` ensures backward compatibility for users who haven't migrated yet.
```diff
 @pytest.mark.parametrize("D", test_d_list)
 @pytest.mark.parametrize("gate_logit_normalizer", test_gate_list)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float16])
```
💡 Verification agent
🧩 Analysis chain
Changed test dtype from bfloat16 to float16
All test cases now use torch.float16 instead of torch.bfloat16. This appears to be a deliberate change across multiple test functions.
Also applies to: 190-190, 266-266, 337-337
🏁 Script executed:
```shell
#!/bin/bash
# Check if there's a pattern of similar changes across the codebase
rg -A 1 "@pytest.mark.parametrize\(\"dtype\"" --glob "tests/**/*.py"
```

Length of output: 10380
Action Required: Clarify and Unify Test Dtype Usage in test_simple_gla.py
The change at line 118 in tests/ops/test_simple_gla.py updates the dtype parameter to use torch.float16 instead of torch.bfloat16. However, note the following based on our search:
- Several parameterizations in this file (e.g., at lines 118, 190, 266) now use `torch.float16`, but there is at least one instance (around line 337) where the dtype is still set as `[torch.float, torch.bfloat16]`.
- Moreover, a quick scan of other test files (such as tests/test_model.py and tests/ops/test_utils.py) reveals that not all tests have switched from `torch.bfloat16` to `torch.float16`.
Recommendations:
- If the intent was to apply a uniform change across all test cases in test_simple_gla.py, update the remaining parameterization(s) (e.g., change `[torch.float, torch.bfloat16]` to the appropriate value) for consistency.
- Consider whether this change should be applied globally across the entire test suite. If so, update other files that still use `torch.bfloat16` or, if intentional, add a comment clarifying the rationale behind the mixed usage.
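One way to implement the first recommendation is a single module-level dtype list that every parametrization in the file references, so the choice is changed in one place. The names below are hypothetical, not taken from the repository:

```python
import pytest
import torch

# single source of truth for the dtypes exercised by this test module
TEST_DTYPES = [torch.float16]

@pytest.mark.parametrize("dtype", TEST_DTYPES)
def test_cast_roundtrip(dtype):
    # a trivial check: casting to the parametrized dtype sticks
    x = torch.randn(4, 4, dtype=torch.float32).to(dtype)
    assert x.dtype == dtype
```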
Force-pushed from d503477 to 17d5c21
Actionable comments posted: 0
🧹 Nitpick comments (3)
fla/ops/gsa/chunk.py (3)
10-10: Importing reduce
The import `reduce` from `einops` is introduced. Verify that it's actually used. If it is unused, consider removing it to keep imports minimal.

550-550: Extract shape from k
Explicitly unpacking `B, T, H, K, V = *k.shape, v.shape[-1]` might cause confusion if `k` or `v` shapes differ from expected. Consider validating shapes or adding an assertion for clarity.
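The unpacking idiom and the suggested guard might look like this, with illustrative shapes:

```python
import torch

k = torch.randn(2, 128, 4, 64)   # [B, T, H, K]
v = torch.randn(2, 128, 4, 32)   # [B, T, H, V]

# one-liner unpack: the first four dims come from k, V from v's last dim
B, T, H, K, V = *k.shape, v.shape[-1]

# the assertion the review suggests: k and v must agree on [B, T, H]
assert k.shape[:3] == v.shape[:3], "k and v must share the leading [B, T, H] dims"
assert (B, T, H, K, V) == (2, 128, 4, 64, 32)
```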
1006-1006: Defaulting head_first to False
Changing the default from `True` to `False` is a breaking API change for existing code. Communicate this clearly in release notes to avoid user confusion.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- fla/ops/gla/chunk.py (36 hunks)
- fla/ops/gsa/chunk.py (28 hunks)
- fla/ops/gsa/fused_recurrent.py (10 hunks)
- tests/ops/test_gsa.py (2 hunks)
🔇 Additional comments (19)
tests/ops/test_gsa.py (2)
147-147: Validate minimum T
The new range `torch.randperm(T - 16)` requires that `T` be at least 16. Please confirm or enforce that `T >= 16` to avoid runtime errors.
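A hedged reconstruction of the failure mode: `torch.randperm` rejects a negative length, so a guard like the one suggested avoids it. The helper name is hypothetical:

```python
import torch

def sample_split_points(T: int, n: int) -> torch.Tensor:
    """Pick n distinct values from torch.randperm(T - 16), guarding the precondition."""
    # torch.randperm requires a non-negative length, so T must be >= 16 here
    if T < 16:
        raise ValueError(f"T must be at least 16, got {T}")
    return torch.randperm(T - 16)[:n]
```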
329-329: Validate minimum T
Similarly, ensure `T >= 16` for the random permutation logic to work without errors.

fla/ops/gsa/chunk.py (10)
4-4: Importing warnings
You've introduced an import for the `warnings` module, presumably for deprecation notices. Looks good.

13-14: Added chunk utilities
You're now importing `prepare_chunk_indices` and the `chunk_gla_*` modules. Ensure these references are used consistently and the imports are necessary.

92-93: Revised pointer initialization
These lines update how the `p_o` and `p_A` pointers are created for block operations. Using `(bos * HQ + i_hq) * V` and `(bos * HQ + i_hq) * BT` appears correct, but please confirm alignment remains consistent with the expected tensor shapes.
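The pointer arithmetic can be checked against plain row-major indexing; a small sketch, with dimension names following the review and arbitrary tensor contents:

```python
import torch

def base_offset(bos: int, i_hq: int, HQ: int, V: int) -> int:
    # row-major offset of element [bos, i_hq, 0] in a [T, HQ, V] tensor,
    # matching the (bos * HQ + i_hq) * V expression from the kernel
    return (bos * HQ + i_hq) * V

T, HQ, V = 6, 4, 8
x = torch.arange(T * HQ * V).reshape(T, HQ, V)
flat = x.reshape(-1)
assert flat[base_offset(3, 2, HQ, V)] == x[3, 2, 0]
```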
554-554: Hard-coded assumption for HQ
Setting `HQ = q.shape[2]` assumes the shape is `[B, T, HQ, ...]`. If the user passes `[B, H, T, ...]`, this could lead to out-of-bound indexing or shape mismatches.

557-557: Conditional index preparation
`indices = prepare_chunk_indices(offsets, BT)` is a neat approach for varlen sequences. Just ensure that downstream usage covers all edge cases when `offsets` is `None`.
1084-1087: Deprecation warning and rearrange
You've added a deprecation warning for `head_first` and force a shape rearrange from `[b h t ...]` to `[b t h ...]`. This is a helpful transitional measure.

1088-1094: Warning on shape mismatch
The warning clarifies potential user error if `[B, T, H, ...]` is accidentally passed in head-first format. Good practice to reduce confusion.

1292-1296: Deprecated head_first in chunk_gsa
This block again warns about deprecation and rearranges shapes if `head_first` is used. Keeping a consistent approach to deprecating `head_first` is beneficial.

1297-1304: Format mismatch warning
Similar logic to lines 1088-1094. Good job preserving consistent user notifications across files.

1322-1322: Returning final_state
This updated return statement ensures both the output and `final_state` are produced. Make sure any downstream usage is updated accordingly.

fla/ops/gsa/fused_recurrent.py (2)
4-4: Add warnings import
Importing `warnings` to handle deprecation or user notifications is consistent with the approach in other files.

10-10: Added rearrange import
Imported `rearrange` from `einops`. Make sure any newly introduced shape transformations are correct.

fla/ops/gla/chunk.py (5)
4-4: Add warnings import
This import is consistent with the other GLA modules for deprecation notices.

10-10: Use of rearrange
Introducing `rearrange` from `einops` allows flexible tensor shape transformations.

1292-1296: Deprecation block for head_first
This block mirrors the pattern used elsewhere, providing a warning and rearranging shapes if `head_first` is used.

1297-1304: Shape mismatch checks
These lines add a warning if `q.shape[1] < q.shape[2]` when `head_first=False`. Good for preventing user confusion.
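The `q.shape[1] < q.shape[2]` heuristic can be sketched as a standalone check. This is a hypothetical helper, not the library's actual code:

```python
import warnings
import torch

def warn_if_head_first_layout(q: torch.Tensor) -> bool:
    """Warn when a [B, H, T, K] tensor is likely passed where [B, T, H, K] is expected.

    Heuristic: sequence length T is usually much larger than head count H,
    so q.shape[1] < q.shape[2] hints that a head-first tensor was passed.
    """
    suspicious = q.shape[1] < q.shape[2]
    if suspicious:
        warnings.warn(
            "Input tensor shape suggests [B, H, T, ...] (head-first); "
            "expected [B, T, H, ...] with head_first=False."
        )
    return suspicious

# [B, H, T, K] with H=4, T=128 trips the check; [B, T, H, K] does not
assert warn_if_head_first_layout(torch.empty(2, 4, 128, 64)) is True
assert warn_if_head_first_layout(torch.empty(2, 128, 4, 64)) is False
```

Being a heuristic, it can misfire on short sequences with many heads, which is presumably why the library warns rather than raises.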
1322-1322: Return assignment
Returning `o, final_state` from the `ChunkGLAFunction` call. Confirm that all call sites expect this new structure.
Summary by CodeRabbit
Documentation
Refactor
User Warnings
Tests