Fix #390: Add missing fwd_prepare_T function#563
Fix #390: Add missing fwd_prepare_T function#563liqiongyu wants to merge 1 commit intofla-org:mainfrom
Conversation
- Implement fwd_prepare_T in wy_fast.py to resolve import error - Handle tensor format conversion between head-first and seq-first - Add comprehensive tests for parallel_delta_rule and fwd_prepare_T - Document tensor format expectations This fixes the ImportError when importing fwd_prepare_T from fla.ops.delta_rule.wy_fast and properly handles the format mismatch between head-first [B, H, T, K] and seq-first [B, T, H, K] tensors.
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. Caution Review failedThe pull request is closed. WalkthroughAdds a new public function fwd_prepare_T to fla/ops/delta_rule/wy_fast.py to compute a transformation matrix used in the delta-rule path, including support for chunking and variable-length sequences. Introduces tests for parallel_delta_rule and fwd_prepare_T, verifying importability, execution, and output shapes. Changes
Sequence Diagram(s)sequenceDiagram
participant Caller
participant wy_fast.fwd_prepare_T
participant chunk_scaled_dot_kkt_fwd
participant solve_tril
Caller->>wy_fast.fwd_prepare_T: k[B,H,T,K], beta[B,H,T], chunk_size, cu_seqlens?
wy_fast.fwd_prepare_T->>wy_fast.fwd_prepare_T: transpose to [B,T,H,K]/[B,T,H]
wy_fast.fwd_prepare_T->>chunk_scaled_dot_kkt_fwd: compute KKT chunks (fp32)
chunk_scaled_dot_kkt_fwd-->>wy_fast.fwd_prepare_T: chunked lower-triangular system
wy_fast.fwd_prepare_T->>solve_tril: solve (I - tril(beta K K^T))^{-1}
solve_tril-->>wy_fast.fwd_prepare_T: A in seq-first
wy_fast.fwd_prepare_T->>wy_fast.fwd_prepare_T: transpose back to [B,H,T,chunk_size]
wy_fast.fwd_prepare_T-->>Caller: A[B,H,T,chunk_size]
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes Assessment against linked issues
Assessment against linked issues: Out-of-scope changes(none) Possibly related PRs
Suggested reviewers
Poem
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (2)
✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Summary of Changes
Hello @liqiongyu, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses issue #390 by implementing the previously missing fwd_prepare_T function, which was causing an ImportError in fla/ops/delta_rule/parallel.py. The new function correctly computes the transformation matrix required for the parallel delta rule algorithm and resolves a critical tensor format mismatch by handling conversions between head-first and seq-first layouts. Comprehensive tests have been added to validate the functionality and ensure the stability of the parallel delta rule implementation.
Highlights
- Missing Function Implementation: The
fwd_prepare_Tfunction has been added tofla/ops/delta_rule/wy_fast.py. This function is responsible for preparing the transformation matrixAfor delta rule computation, specificallyA = (I - tril(beta * K * K^T))^{-1}. - Tensor Format Handling: The newly implemented
fwd_prepare_Tfunction now correctly handles tensor format conversions, transforming inputkandbetafrom head-first[B, H, T, K]to seq-first[B, T, H, K]for internal processing, and then converting the result back to head-first[B, H, T, chunk_size]. - New Test File and Coverage: A new test file,
tests/ops/test_parallel_delta.py, has been introduced. It includestest_parallel_delta_ruleto verify the main parallel delta rule function andtest_fwd_prepare_Tto specifically test the newly addedfwd_prepare_Tfunction, ensuring its correct import, execution, and output shape.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Code Review
This pull request successfully addresses the ImportError by adding the missing fwd_prepare_T function and correctly handling the tensor format conversions. The inclusion of new tests is a great addition to ensure the fix is working as expected. I've included a few minor suggestions related to code style and formatting to improve overall code quality.
|
|
||
| def fwd_prepare_T( | ||
| k: torch.Tensor, | ||
| beta: torch.Tensor, |
There was a problem hiding this comment.
There is a trailing whitespace at the end of this line. It's good practice to remove it to maintain code cleanliness and adhere to the PEP 8 style guide.1
| beta: torch.Tensor, | |
| beta: torch.Tensor, |
Style Guide References
Footnotes
-
PEP 8: Avoid trailing whitespace. It can be confusing and some editors are configured to remove it automatically. ↩
| ) | ||
| def test_parallel_delta_rule( | ||
| B: int, | ||
| H: int, |
There was a problem hiding this comment.
There is a trailing whitespace at the end of this line. It's good practice to remove it to maintain code cleanliness and adhere to the PEP 8 style guide.1
| H: int, | |
| H: int, |
Style Guide References
Footnotes
-
PEP 8: Avoid trailing whitespace. It can be confusing and some editors are configured to remove it automatically. ↩
| # Check output shape | ||
| # After our fix, fwd_prepare_T returns [B, H, T, chunk_size] (head-first format) | ||
| expected_shape = (B, H, T, chunk_size) | ||
| assert A.shape == expected_shape, f"Expected shape {expected_shape}, got {A.shape}" No newline at end of file |
There was a problem hiding this comment.
The file is missing a newline at the end. It's a good practice to end files with a single newline character for consistency and to avoid issues with certain tools, as recommended by PEP 8.1
assert A.shape == expected_shape, f"Expected shape {expected_shape}, got {A.shape}"\nStyle Guide References
Footnotes
-
PEP 8: All files should end in a single newline. ↩
Description
This PR fixes #390 by implementing the missing
fwd_prepare_Tfunction.Problem
fla/ops/delta_rule/parallel.pytried to importfwd_prepare_Tfromwy_fast.pySolution
fwd_prepare_Tfunction infla/ops/delta_rule/wy_fast.pyChanges
fwd_prepare_Timplementation with proper format handlingtests/ops/test_parallel_delta.pyTesting
test_fwd_prepare_Tpasses successfullytest_parallel_delta_ruleverifies the function works correctlyFixes #390
Summary by CodeRabbit
New Features
Tests