Conversation
WalkthroughThis update refactors numerous functions across the ops and tests modules by removing the optional Changes
Sequence Diagram(s)sequenceDiagram
participant Caller
participant Function
Caller->>Function: Call function with tensors & offsets
Note over Function: Removed optional indices/head_first parameters
Function->>Function: If offsets provided, compute indices via prepare_chunk_indices
Function-->>Caller: Return computed tensor results
Possibly related PRs
Poem
✨ Finishing Touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Caution
Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.
Actionable comments posted: 1
🧹 Nitpick comments (23)
fla/ops/forgetting_attn/parallel.py (2)
52-55: Add stacklevel to the warning messageThe warning about deprecated
head_firstparameter should include astacklevelparameter to ensure the warning points to the user's code rather than this internal function.warnings.warn( "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." + "Please use head_first=False for now instead.", + stacklevel=2 )🧰 Tools
🪛 Ruff (0.8.2)
52-52: No explicit
stacklevelkeyword argument found(B028)
57-63: Add stacklevel to the format mismatch warningThe format mismatch warning should include a
stacklevelparameter to properly attribute the warning to the calling code.warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 )🧰 Tools
🪛 Ruff (0.8.2)
58-58: No explicit
stacklevelkeyword argument found(B028)
fla/ops/ttt/chunk.py (2)
1390-1393: Add stacklevel to the warning messageThe warning about deprecated
head_firstparameter should include astacklevelparameter to ensure the warning points to the user's code rather than this internal function.warnings.warn( "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." + "Please use head_first=False for now instead.", + stacklevel=2 )🧰 Tools
🪛 Ruff (0.8.2)
1390-1390: No explicit
stacklevelkeyword argument found(B028)
1396-1401: Add stacklevel to the format mismatch warningThe format mismatch warning should include a
stacklevelparameter to properly attribute the warning to the calling code.warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 )🧰 Tools
🪛 Ruff (0.8.2)
1396-1396: No explicit
stacklevelkeyword argument found(B028)
fla/ops/ttt/fused_chunk.py (1)
792-804: Added appropriate deprecation warnings for head_first parameterGood addition of warning messages to notify users about:
- The deprecation of the
head_firstparameter- Potential format mismatches when tensor shapes suggest head-first was used incorrectly
Consider adding a
stacklevelparameter to thewarnings.warn()calls to ensure the warnings point to the user's code rather than the library code.- warnings.warn( + warnings.warn( "head_first is deprecated and will be removed in a future version. " "Please use head_first=False for now instead." - ) + stacklevel=2 + ) - warnings.warn( + warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
793-793: No explicit
stacklevelkeyword argument found(B028)
799-799: No explicit
stacklevelkeyword argument found(B028)
fla/ops/rwkv6/chunk.py (1)
1267-1279: Added deprecation warnings with clear guidanceProper warning messages have been added to notify users about:
- The deprecation of the
head_firstparameter- Potential tensor shape issues when head_first=False but the input shape suggests head-first format
As with the previous file, consider adding the
stacklevelparameter to warnings to improve the developer experience.- warnings.warn( + warnings.warn( "head_first is deprecated and will be removed in a future version. " "Please use head_first=False for now instead." - ) + stacklevel=2 + ) - warnings.warn( + warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
1268-1268: No explicit
stacklevelkeyword argument found(B028)
1274-1274: No explicit
stacklevelkeyword argument found(B028)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
285-297: Added consistent deprecation warnings for head_first parameterThe warnings added here are consistent with those in other files, providing clear guidance to users about:
- The deprecation of the
head_firstparameter- Potential format mismatches when tensor shapes suggest incorrect usage
As with the previous files, consider adding
stacklevel=2to the warnings to improve the developer experience by pointing to the user's code rather than the library code.- warnings.warn( + warnings.warn( "head_first is deprecated and will be removed in a future version. " "Please use head_first=False for now instead." - ) + stacklevel=2 + ) - warnings.warn( + warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
286-286: No explicit
stacklevelkeyword argument found(B028)
292-292: No explicit
stacklevelkeyword argument found(B028)
fla/ops/rwkv6/fused_recurrent.py (1)
637-649: Added deprecation warnings for head_first parameter.These warning messages properly inform users about the deprecation of the
head_firstparameter and potential format mismatches. However, they're missing thestacklevelparameter which helps indicate the correct source of the warning in the user's code.- warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2 + )- warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
638-638: No explicit
stacklevelkeyword argument found(B028)
644-644: No explicit
stacklevelkeyword argument found(B028)
fla/ops/delta_rule/chunk.py (1)
305-317: Added deprecation warnings for head_first parameter.These warning messages properly inform users about the deprecation of the
head_firstparameter and potential format mismatches. However, they're missing thestacklevelparameter which helps indicate the correct source of the warning in the user's code.- warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2 + )- warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
306-306: No explicit
stacklevelkeyword argument found(B028)
312-312: No explicit
stacklevelkeyword argument found(B028)
fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py (1)
252-255: Revised dimension extraction and chunk-index logic.
Defining(B, T, H, K) = k.shapeand usingprepare_chunk_indicesis more standardized. Consider verifying that chunk sizes smaller or larger than T still work properly.fla/ops/common/chunk_delta_h.py (1)
255-264: Backward kernel pointer offsets.
Shiftingdh,dv,dv2,q,k,d, anddoby(bos * H + i_h)or(boh * H + i_h)unifies the indexing pattern with the forward pass. Double-check that each pointer shift lines up with how memory is laid out for multi-head variable-length sequences.fla/ops/generalized_delta_rule/iplr/fused_recurrent.py (1)
420-432: Deprecation warnings forhead_first.
You might want to setstacklevel=2or higher so that the warning message points to the user’s call site, satisfying the static analysis hint (B028). Additionally, rearranging witheinopsis correct, but be sure the user is aware of any performance cost.Apply this small tweak to each
warnings.warncall:-warnings.warn( +warnings.warn( "head_first is deprecated ...", UserWarning, + stacklevel=2 )🧰 Tools
🪛 Ruff (0.8.2)
421-421: No explicit
stacklevelkeyword argument found(B028)
427-427: No explicit
stacklevelkeyword argument found(B028)
fla/ops/simple_gla/parallel.py (1)
695-707: Added deprecation warnings for head_first parameter.Good practice to warn users about the upcoming removal of the
head_firstparameter, but the warnings should include astacklevelparameter to ensure they point to the correct location in user code.- warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2 + )Similarly for the second warning:
- warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
696-696: No explicit
stacklevelkeyword argument found(B028)
702-702: No explicit
stacklevelkeyword argument found(B028)
fla/ops/generalized_delta_rule/dplr/fused_recurrent.py (1)
252-263: Missing stacklevel in deprecation warnings.The warnings correctly inform users about the deprecation of
head_firstand potential shape mismatches, but they should include astacklevelparameter to ensure the warnings refer to the caller's code rather than the library itself.- warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2 + )- warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
252-252: No explicit
stacklevelkeyword argument found(B028)
258-258: No explicit
stacklevelkeyword argument found(B028)
fla/ops/generalized_delta_rule/dplr/chunk.py (1)
321-333: Missing stacklevel in deprecation warnings.The warnings correctly inform users about the deprecation of
head_firstand potential shape mismatches, but they should include astacklevelparameter to ensure the warnings refer to the caller's code rather than the library itself.- warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2 + )- warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
322-322: No explicit
stacklevelkeyword argument found(B028)
328-328: No explicit
stacklevelkeyword argument found(B028)
fla/ops/gated_delta_rule/chunk.py (3)
317-321: Add stacklevel parameter to the warning callFor better user experience, add the stacklevel parameter to ensure the warning points to the caller's code rather than the library internals.
- warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
318-318: No explicit
stacklevelkeyword argument found(B028)
323-329: Add stacklevel parameter to the warning callSimilar to the previous warning, add the stacklevel parameter to ensure the warning points to the caller's code.
- warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
324-324: No explicit
stacklevelkeyword argument found(B028)
251-284: Update docstring to reflect parameter removalThe docstring still mentions the
head_firstparameter, but it's being deprecated and removed from the function signature. Consider updating the docstring to indicate its deprecated status.# In the docstring - head_first (Optional[bool]): - Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `False`. + head_first (Optional[bool]): + [DEPRECATED] Whether the inputs are in the head-first format. + This parameter is deprecated and will be removed in a future version. + Default: `False`.fla/ops/delta_rule/fused_recurrent.py (3)
517-520: Add stacklevel parameter to the warning callAdd the stacklevel parameter to ensure the warning points to the caller's code rather than the library internals.
- warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
517-517: No explicit
stacklevelkeyword argument found(B028)
523-528: Add stacklevel parameter to the warning callSimilar to the previous warning, add the stacklevel parameter for better user experience.
- warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
523-523: No explicit
stacklevelkeyword argument found(B028)
332-336: Use ternary operator for cleaner codeConsider using a ternary operator for creating the
dbtensor, as suggested by the static analysis tool. This makes the code more concise.- if beta_vector: - db = q.new_empty(NV, NK, B, T, H, V) - else: - db = q.new_empty(NV, B, T, H) + db = q.new_empty(NV, NK, B, T, H, V) if beta_vector else q.new_empty(NV, B, T, H)🧰 Tools
🪛 Ruff (0.8.2)
332-335: Use ternary operator
db = q.new_empty(NV, NK, B, T, H, V) if beta_vector else q.new_empty(NV, B, T, H)instead ofif-else-blockReplace
if-else-block withdb = q.new_empty(NV, NK, B, T, H, V) if beta_vector else q.new_empty(NV, B, T, H)(SIM108)
fla/ops/generalized_delta_rule/iplr/chunk.py (2)
465-468: Add stacklevel parameter to the warning callAdd the stacklevel parameter to ensure the warning points to the caller's code rather than the library internals.
- warnings.warn( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead." - ) + warnings.warn( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
465-465: No explicit
stacklevelkeyword argument found(B028)
471-476: Add stacklevel parameter to the warning callSimilar to the previous warning, add the stacklevel parameter for better user experience.
- warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...]." - ) + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2 + )🧰 Tools
🪛 Ruff (0.8.2)
471-471: No explicit
stacklevelkeyword argument found(B028)
🛑 Comments failed to post (1)
fla/ops/common/chunk_delta_h.py (1)
299-299: 🛠️ Refactor suggestion
Potential gating boundary check.
last_idx = min((i_t + 1) * BT, T) - 1can become-1ifT=0, though that’s presumably not a valid scenario in practice. If the code might be used with empty or zero-length sequences, consider adding a safeguard.
* [RWKV7] add `input_precision` param * [Test] Add skip params to CI (#336) [skip test] * [Triton] Fix `Call Undecorated gather Functio` on Triton 3.1.0 (#329) [skip test] * [Deprecated] Remove `head_first` option in gla variants (#337) * [CI]: fix for pass [skip test] * [Test] Ensure most tests on Triton 3.2.0 and add `4096` seq_length in tests [skip test] (#300) * [FoX] Merge code to FlashAttention | support batch inference (#333) * [FoX] Merge code to FlashAttention | support batch inference * [Inference] More flexible cache definitions for attn state * Fix dim mismatches and improved unpad_input fn * Fix bugs of cu_seqlens --------- Co-authored-by: Yu Zhang <yzhang.cs@outlook.com> * [DeltaNet] Delete `head_first` option for all (#338) * [WIP] Remove head_first option (#339) --------- Co-authored-by: Yu Zhang <yzhang.cs@outlook.com> Co-authored-by: Songlin Yang <yangsl66@mit.edu>
Summary by CodeRabbit
Refactor
Tests & Chores