Skip to content

[WIP] Remove head_first option#339

Merged
yzhangcs merged 3 commits intomainfrom
head_first
Apr 8, 2025
Merged

[WIP] Remove head_first option#339
yzhangcs merged 3 commits intomainfrom
head_first

Conversation

@yzhangcs
Copy link
Copy Markdown
Member

@yzhangcs yzhangcs commented Apr 8, 2025

Summary by CodeRabbit

  • Refactor

    • Streamlined APIs across multiple components by removing deprecated parameters and conditional logic for input tensor formats. Consolidated pointer calculations and unified tensor shape handling to improve consistency and maintainability.
  • Tests & Chores

    • Updated test configurations to shift from an older low-precision format to float16, ensuring consistent numerical behavior and performance across test suites.
    • Added informative deprecation warnings to guide users toward the updated input formats.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 8, 2025

Walkthrough

This update refactors numerous functions across the ops and tests modules by removing the optional indices and head_first parameters. Functions in attention, delta rule, gated delta rule, generalized delta rule, RWKV6, TTT, and related modules now internally compute indices using the provided offsets (via prepare_chunk_indices) and apply a unified tensor shape handling approach. Additionally, several conditional branches and pointer arithmetic based on the removed parameters have been simplified. Test files have also been updated to use torch.float16 instead of torch.bfloat16 and to adjust list order and formatting.

Changes

File(s) Change Summary
fla/ops/attn/parallel.py Removed indices from parallel_attn_fwd & parallel_attn_bwd; now computing indices internally from offsets.
fla/ops/common/[chunk_delta_h.py, chunk_h_split.py, chunk_scaled_dot_kkt.py] Removed HEAD_FIRST/head_first parameters and their conditional logic; added preprocess_qkw in chunk_delta_h.py; unified tensor shape unpacking.
fla/ops/delta_rule/[chunk.py, fused_recurrent.py, wy_fast.py] Eliminated indices and head_first parameters; inserted warnings for deprecated usage; adjusted pointer arithmetic to rely solely on offsets.
fla/ops/gated_delta_rule/[chunk.py, fused_recurrent.py, wy_fast.py] Removed indices and head_first parameters; added warnings to notify users about tensor format expectations; streamlined calls to lower-level functions.
fla/ops/generalized_delta_rule/dplr/... Excised HEAD_FIRST/head_first from multiple functions; standardized pointer, index, and offset computations using prepare_chunk_indices; unified tensor shape handling.
fla/ops/[rwkv6/*, simple_gla/parallel.py, ttt/*, ttt/fused_chunk.py] Consistently removed HEAD_FIRST/head_first parameters; simplified tensor indexing and pointer calculations across modules.
fla/ops/utils/solve_tril.py Eliminated the HEAD_FIRST parameter and its associated conditional logic; updated matrix index and stride computations.
tests/ops/* Updated dtype parameters from torch.bfloat16 to torch.float16; reordered test lists and refined formatting and transposition adjustments in various test functions.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant Function
    Caller->>Function: Call function with tensors & offsets
    Note over Function: Removed optional indices/head_first parameters
    Function->>Function: If offsets provided, compute indices via prepare_chunk_indices
    Function-->>Caller: Return computed tensor results
Loading

Possibly related PRs

Poem

I'm a rabbit with a skip in my hop,
Removing extra params, I just can't stop!
Offsets guide me as indices take shape,
Cleaner code is what we now make.
Carrots and code, both fresh and bright—
Leaping through logic with pure delight! 🥕🐇

✨ Finishing Touches
  • 📝 Generate Docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai plan to trigger planning for file edits and PR creation.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@yzhangcs yzhangcs merged commit 627a44d into main Apr 8, 2025
5 checks passed
@yzhangcs yzhangcs deleted the head_first branch April 8, 2025 20:55
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.

Actionable comments posted: 1

🧹 Nitpick comments (23)
fla/ops/forgetting_attn/parallel.py (2)

52-55: Add stacklevel to the warning message

The warning about deprecated head_first parameter should include a stacklevel parameter to ensure the warning points to the user's code rather than this internal function.

 warnings.warn(
     "head_first is deprecated and will be removed in a future version. "
-    "Please use head_first=False for now instead."
+    "Please use head_first=False for now instead.",
+    stacklevel=2
 )
🧰 Tools
🪛 Ruff (0.8.2)

52-52: No explicit stacklevel keyword argument found

(B028)


57-63: Add stacklevel to the format mismatch warning

The format mismatch warning should include a stacklevel parameter to properly attribute the warning to the calling code.

 warnings.warn(
     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
     "when head_first=False was specified. "
-    "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+    "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+    stacklevel=2
 )
🧰 Tools
🪛 Ruff (0.8.2)

58-58: No explicit stacklevel keyword argument found

(B028)

fla/ops/ttt/chunk.py (2)

1390-1393: Add stacklevel to the warning message

The warning about deprecated head_first parameter should include a stacklevel parameter to ensure the warning points to the user's code rather than this internal function.

 warnings.warn(
     "head_first is deprecated and will be removed in a future version. "
-    "Please use head_first=False for now instead."
+    "Please use head_first=False for now instead.",
+    stacklevel=2
 )
🧰 Tools
🪛 Ruff (0.8.2)

1390-1390: No explicit stacklevel keyword argument found

(B028)


1396-1401: Add stacklevel to the format mismatch warning

The format mismatch warning should include a stacklevel parameter to properly attribute the warning to the calling code.

 warnings.warn(
     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
     "when head_first=False was specified. "
-    "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+    "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+    stacklevel=2
 )
🧰 Tools
🪛 Ruff (0.8.2)

1396-1396: No explicit stacklevel keyword argument found

(B028)

fla/ops/ttt/fused_chunk.py (1)

792-804: Added appropriate deprecation warnings for head_first parameter

Good addition of warning messages to notify users about:

  1. The deprecation of the head_first parameter
  2. Potential format mismatches when tensor shapes suggest head-first was used incorrectly

Consider adding a stacklevel parameter to the warnings.warn() calls to ensure the warnings point to the user's code rather than the library code.

-        warnings.warn(
+        warnings.warn(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
-        )
+            stacklevel=2
+        )

-        warnings.warn(
+        warnings.warn(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "
             "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

793-793: No explicit stacklevel keyword argument found

(B028)


799-799: No explicit stacklevel keyword argument found

(B028)

fla/ops/rwkv6/chunk.py (1)

1267-1279: Added deprecation warnings with clear guidance

Proper warning messages have been added to notify users about:

  1. The deprecation of the head_first parameter
  2. Potential tensor shape issues when head_first=False but the input shape suggests head-first format

As with the previous file, consider adding the stacklevel parameter to warnings to improve the developer experience.

-        warnings.warn(
+        warnings.warn(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
-        )
+            stacklevel=2
+        )

-        warnings.warn(
+        warnings.warn(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "
             "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

1268-1268: No explicit stacklevel keyword argument found

(B028)


1274-1274: No explicit stacklevel keyword argument found

(B028)

fla/ops/gated_delta_rule/fused_recurrent.py (1)

285-297: Added consistent deprecation warnings for head_first parameter

The warnings added here are consistent with those in other files, providing clear guidance to users about:

  1. The deprecation of the head_first parameter
  2. Potential format mismatches when tensor shapes suggest incorrect usage

As with the previous files, consider adding stacklevel=2 to the warnings to improve the developer experience by pointing to the user's code rather than the library code.

-        warnings.warn(
+        warnings.warn(
             "head_first is deprecated and will be removed in a future version. "
             "Please use head_first=False for now instead."
-        )
+            stacklevel=2
+        )

-        warnings.warn(
+        warnings.warn(
             f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
             "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
             "when head_first=False was specified. "
             "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

286-286: No explicit stacklevel keyword argument found

(B028)


292-292: No explicit stacklevel keyword argument found

(B028)

fla/ops/rwkv6/fused_recurrent.py (1)

637-649: Added deprecation warnings for head_first parameter.

These warning messages properly inform users about the deprecation of the head_first parameter and potential format mismatches. However, they're missing the stacklevel parameter which helps indicate the correct source of the warning in the user's code.

-        warnings.warn(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead.",
+            stacklevel=2
+        )
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({r.shape[1]}) < num_heads ({r.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

638-638: No explicit stacklevel keyword argument found

(B028)


644-644: No explicit stacklevel keyword argument found

(B028)

fla/ops/delta_rule/chunk.py (1)

305-317: Added deprecation warnings for head_first parameter.

These warning messages properly inform users about the deprecation of the head_first parameter and potential format mismatches. However, they're missing the stacklevel parameter which helps indicate the correct source of the warning in the user's code.

-        warnings.warn(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead.",
+            stacklevel=2
+        )
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

306-306: No explicit stacklevel keyword argument found

(B028)


312-312: No explicit stacklevel keyword argument found

(B028)

fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py (1)

252-255: Revised dimension extraction and chunk-index logic.
Defining (B, T, H, K) = k.shape and using prepare_chunk_indices is more standardized. Consider verifying that chunk sizes smaller or larger than T still work properly.

fla/ops/common/chunk_delta_h.py (1)

255-264: Backward kernel pointer offsets.
Shifting dh, dv, dv2, q, k, d, and do by (bos * H + i_h) or (boh * H + i_h) unifies the indexing pattern with the forward pass. Double-check that each pointer shift lines up with how memory is laid out for multi-head variable-length sequences.

fla/ops/generalized_delta_rule/iplr/fused_recurrent.py (1)

420-432: Deprecation warnings for head_first.
You might want to set stacklevel=2 or higher so that the warning message points to the user’s call site, satisfying the static analysis hint (B028). Additionally, rearranging with einops is correct, but be sure the user is aware of any performance cost.

Apply this small tweak to each warnings.warn call:

-warnings.warn(
+warnings.warn(
    "head_first is deprecated ...",
    UserWarning,
+   stacklevel=2
)
🧰 Tools
🪛 Ruff (0.8.2)

421-421: No explicit stacklevel keyword argument found

(B028)


427-427: No explicit stacklevel keyword argument found

(B028)

fla/ops/simple_gla/parallel.py (1)

695-707: Added deprecation warnings for head_first parameter.

Good practice to warn users about the upcoming removal of the head_first parameter, but the warnings should include a stacklevel parameter to ensure they point to the correct location in user code.

-        warnings.warn(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead.",
+            stacklevel=2
+        )

Similarly for the second warning:

-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

696-696: No explicit stacklevel keyword argument found

(B028)


702-702: No explicit stacklevel keyword argument found

(B028)

fla/ops/generalized_delta_rule/dplr/fused_recurrent.py (1)

252-263: Missing stacklevel in deprecation warnings.

The warnings correctly inform users about the deprecation of head_first and potential shape mismatches, but they should include a stacklevel parameter to ensure the warnings refer to the caller's code rather than the library itself.

-        warnings.warn(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead.",
+            stacklevel=2
+        )
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

252-252: No explicit stacklevel keyword argument found

(B028)


258-258: No explicit stacklevel keyword argument found

(B028)

fla/ops/generalized_delta_rule/dplr/chunk.py (1)

321-333: Missing stacklevel in deprecation warnings.

The warnings correctly inform users about the deprecation of head_first and potential shape mismatches, but they should include a stacklevel parameter to ensure the warnings refer to the caller's code rather than the library itself.

-        warnings.warn(
-            "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
-        )
+        warnings.warn(
+            "head_first is deprecated and will be removed in a future version. "
+            "Please use head_first=False for now instead.",
+            stacklevel=2
+        )
-        warnings.warn(
-            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-            "when head_first=False was specified. "
-            "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
-        )
+        warnings.warn(
+            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+            "when head_first=False was specified. "
+            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+            stacklevel=2
+        )
🧰 Tools
🪛 Ruff (0.8.2)

322-322: No explicit stacklevel keyword argument found

(B028)


328-328: No explicit stacklevel keyword argument found

(B028)

fla/ops/gated_delta_rule/chunk.py (3)

317-321: Add stacklevel parameter to the warning call

For better user experience, add the stacklevel parameter to ensure the warning points to the caller's code rather than the library internals.

- warnings.warn(
-     "head_first is deprecated and will be removed in a future version. "
-     "Please use head_first=False for now instead."
- )
+ warnings.warn(
+     "head_first is deprecated and will be removed in a future version. "
+     "Please use head_first=False for now instead.",
+     stacklevel=2
+ )
🧰 Tools
🪛 Ruff (0.8.2)

318-318: No explicit stacklevel keyword argument found

(B028)


323-329: Add stacklevel parameter to the warning call

Similar to the previous warning, add the stacklevel parameter to ensure the warning points to the caller's code.

- warnings.warn(
-     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-     "when head_first=False was specified. "
-     "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
- )
+ warnings.warn(
+     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+     "when head_first=False was specified. "
+     "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+     stacklevel=2
+ )
🧰 Tools
🪛 Ruff (0.8.2)

324-324: No explicit stacklevel keyword argument found

(B028)


251-284: Update docstring to reflect parameter removal

The docstring still mentions the head_first parameter, but it's being deprecated and removed from the function signature. Consider updating the docstring to indicate its deprecated status.

# In the docstring
- head_first (Optional[bool]):
-     Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
-     Default: `False`.
+ head_first (Optional[bool]):
+     [DEPRECATED] Whether the inputs are in the head-first format.
+     This parameter is deprecated and will be removed in a future version.
+     Default: `False`.
fla/ops/delta_rule/fused_recurrent.py (3)

517-520: Add stacklevel parameter to the warning call

Add the stacklevel parameter to ensure the warning points to the caller's code rather than the library internals.

- warnings.warn(
-     "head_first is deprecated and will be removed in a future version. "
-     "Please use head_first=False for now instead."
- )
+ warnings.warn(
+     "head_first is deprecated and will be removed in a future version. "
+     "Please use head_first=False for now instead.",
+     stacklevel=2
+ )
🧰 Tools
🪛 Ruff (0.8.2)

517-517: No explicit stacklevel keyword argument found

(B028)


523-528: Add stacklevel parameter to the warning call

Similar to the previous warning, add the stacklevel parameter for better user experience.

- warnings.warn(
-     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-     "when head_first=False was specified. "
-     "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
- )
+ warnings.warn(
-     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-     "when head_first=False was specified. "
-     "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+     stacklevel=2
+ )
🧰 Tools
🪛 Ruff (0.8.2)

523-523: No explicit stacklevel keyword argument found

(B028)


332-336: Use ternary operator for cleaner code

Consider using a ternary operator for creating the db tensor, as suggested by the static analysis tool. This makes the code more concise.

- if beta_vector:
-     db = q.new_empty(NV, NK, B, T, H, V)
- else:
-     db = q.new_empty(NV, B, T, H)
+ db = q.new_empty(NV, NK, B, T, H, V) if beta_vector else q.new_empty(NV, B, T, H)
🧰 Tools
🪛 Ruff (0.8.2)

332-335: Use ternary operator db = q.new_empty(NV, NK, B, T, H, V) if beta_vector else q.new_empty(NV, B, T, H) instead of if-else-block

Replace if-else-block with db = q.new_empty(NV, NK, B, T, H, V) if beta_vector else q.new_empty(NV, B, T, H)

(SIM108)

fla/ops/generalized_delta_rule/iplr/chunk.py (2)

465-468: Add stacklevel parameter to the warning call

Add the stacklevel parameter to ensure the warning points to the caller's code rather than the library internals.

- warnings.warn(
-     "head_first is deprecated and will be removed in a future version. "
-     "Please use head_first=False for now instead."
- )
+ warnings.warn(
+     "head_first is deprecated and will be removed in a future version. "
+     "Please use head_first=False for now instead.",
+     stacklevel=2
+ )
🧰 Tools
🪛 Ruff (0.8.2)

465-465: No explicit stacklevel keyword argument found

(B028)


471-476: Add stacklevel parameter to the warning call

Similar to the previous warning, add the stacklevel parameter for better user experience.

- warnings.warn(
-     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
-     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
-     "when head_first=False was specified. "
-     "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
- )
+ warnings.warn(
+     f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+     "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+     "when head_first=False was specified. "
+     "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
+     stacklevel=2
+ )
🧰 Tools
🪛 Ruff (0.8.2)

471-471: No explicit stacklevel keyword argument found

(B028)

🛑 Comments failed to post (1)
fla/ops/common/chunk_delta_h.py (1)

299-299: 🛠️ Refactor suggestion

Potential gating boundary check.
last_idx = min((i_t + 1) * BT, T) - 1 can become -1 if T=0, though that’s presumably not a valid scenario in practice. If the code might be used with empty or zero-length sequences, consider adding a safeguard.

zhiyuan1i pushed a commit that referenced this pull request Apr 8, 2025
zhiyuan1i added a commit that referenced this pull request Apr 8, 2025
* [RWKV7] add `input_precision` param

* [Test] Add skip params to CI (#336) [skip test]

* [Triton] Fix `Call Undecorated gather Functio` on Triton 3.1.0 (#329) [skip test]

* [Deprecated] Remove `head_first` option in gla variants (#337)

* [CI]: fix for pass [skip test]

* [Test] Ensure most tests on Triton 3.2.0 and add `4096` seq_length in tests [skip test] (#300)

* [FoX] Merge code to FlashAttention |  support batch inference (#333)

* [FoX] Merge code to FlashAttention |  support batch inference

* [Inference] More flexible cache definitions for attn state

* Fix dim mismatches and improved unpad_input fn

* Fix bugs of cu_seqlens

---------

Co-authored-by: Yu Zhang <yzhang.cs@outlook.com>

* [DeltaNet] Delete `head_first` option for all (#338)

* [WIP] Remove head_first option (#339)

---------

Co-authored-by: Yu Zhang <yzhang.cs@outlook.com>
Co-authored-by: Songlin Yang <yangsl66@mit.edu>
@coderabbitai coderabbitai bot mentioned this pull request Jan 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant