
Fix #390: Add missing fwd_prepare_T function #563

Closed
liqiongyu wants to merge 1 commit into fla-org:main from liqiongyu:fix-issue-390-fwd-prepare-t

Conversation


@liqiongyu liqiongyu commented Aug 13, 2025

Description

This PR fixes #390 by implementing the missing fwd_prepare_T function.

Problem

  • fla/ops/delta_rule/parallel.py tried to import fwd_prepare_T from wy_fast.py
  • The function didn't exist, causing ImportError
  • There was also a tensor format mismatch between head-first and seq-first layouts

Solution

  • Added fwd_prepare_T function in fla/ops/delta_rule/wy_fast.py
  • Handles format conversion from head-first [B, H, T, K] to seq-first [B, T, H, K]
  • Returns result in correct head-first format
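For illustration, here is a minimal pure-PyTorch reference of this construction. This is a sketch, not the PR's actual implementation (which, per the walkthrough, delegates to the Triton-backed chunk_scaled_dot_kkt_fwd and solve_tril); the strictly-lower-triangular masking and row-wise beta scaling are assumptions based on the usual WY construction.

```python
import torch

def fwd_prepare_T_ref(k: torch.Tensor, beta: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Reference for A = (I - tril(diag(beta) K K^T))^{-1}, computed per chunk.
    # k: head-first [B, H, T, K]; beta: head-first [B, H, T].
    B, H, T, K = k.shape
    assert T % chunk_size == 0, "T must be divisible by chunk_size"
    n = T // chunk_size
    kc = k.reshape(B, H, n, chunk_size, K).float()      # chunked keys
    bc = beta.reshape(B, H, n, chunk_size, 1).float()   # per-row scaling
    # Row-scaled Gram matrix per chunk: A[i, j] = beta_i * <k_i, k_j>
    A = bc * (kc @ kc.transpose(-1, -2))                # [B, H, n, C, C]
    A = torch.tril(A, diagonal=-1)                      # keep strictly lower part (assumption)
    eye = torch.eye(chunk_size, dtype=A.dtype, device=A.device)
    # (I - A) is unit lower-triangular, hence always invertible.
    T_mat = torch.linalg.inv(eye - A)
    # Back to head-first [B, H, T, chunk_size].
    return T_mat.reshape(B, H, T, chunk_size).to(k.dtype)
```

The reshape keeps the head-first layout throughout; the actual kernel instead transposes to seq-first for the chunked Triton code paths, as described above.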

Changes

  • Added fwd_prepare_T implementation with proper format handling
  • Created comprehensive tests in tests/ops/test_parallel_delta.py
  • Documented tensor format expectations

Testing

  • test_fwd_prepare_T passes successfully
  • test_parallel_delta_rule verifies the function works correctly

Fixes #390

Summary by CodeRabbit

  • New Features

    • Added a public API to prepare the transformation matrix for the delta-rule, supporting chunked processing and variable-length sequences. Compatible with head-first tensors and integrates with existing delta-rule operations. No breaking changes.
  • Tests

    • Expanded test coverage for the parallel delta-rule across multiple shapes and dtypes, including attention outputs.
    • Added tests for the new transformation preparation API, validating expected output shapes. Certain platform-specific skips applied.

- Implement fwd_prepare_T in wy_fast.py to resolve import error
- Handle tensor format conversion between head-first and seq-first
- Add comprehensive tests for parallel_delta_rule and fwd_prepare_T
- Document tensor format expectations

This fixes the ImportError when importing fwd_prepare_T from
fla.ops.delta_rule.wy_fast and properly handles the format mismatch
between head-first [B, H, T, K] and seq-first [B, T, H, K] tensors.
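As a quick illustration of the head-first/seq-first conversion the commit message describes, the two layouts differ only by a transpose of the H and T axes, and the conversion is its own inverse:

```python
import torch

x_hf = torch.randn(2, 4, 8, 16)   # head-first [B, H, T, K]
x_sf = x_hf.transpose(1, 2)       # seq-first  [B, T, H, K]
assert x_sf.shape == (2, 8, 4, 16)
# Transposing back recovers the original tensor exactly.
assert torch.equal(x_sf.transpose(1, 2), x_hf)
```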

coderabbitai bot commented Aug 13, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Caution

Review failed

The pull request is closed.

Walkthrough

Adds a new public function fwd_prepare_T to fla/ops/delta_rule/wy_fast.py to compute a transformation matrix used in the delta-rule path, including support for chunking and variable-length sequences. Introduces tests for parallel_delta_rule and fwd_prepare_T, verifying importability, execution, and output shapes.

Changes

Cohort / File(s) — Summary

Delta-rule ops (fla/ops/delta_rule/wy_fast.py):
Adds fwd_prepare_T(k, beta, chunk_size, cu_seqlens=None), computing A via chunk_scaled_dot_kkt_fwd (fp32) and solve_tril, with head-first/seq-first transposes; returns [B, H, T, chunk_size].

Delta-rule tests (tests/ops/test_parallel_delta.py):
New tests: shape/execution checks for parallel_delta_rule with output_attentions, and an import/execution check for fwd_prepare_T; skipped on Intel platforms.

Sequence Diagram(s)

sequenceDiagram
  participant Caller
  participant wy_fast.fwd_prepare_T
  participant chunk_scaled_dot_kkt_fwd
  participant solve_tril

  Caller->>wy_fast.fwd_prepare_T: k[B,H,T,K], beta[B,H,T], chunk_size, cu_seqlens?
  wy_fast.fwd_prepare_T->>wy_fast.fwd_prepare_T: transpose to [B,T,H,K]/[B,T,H]
  wy_fast.fwd_prepare_T->>chunk_scaled_dot_kkt_fwd: compute KKT chunks (fp32)
  chunk_scaled_dot_kkt_fwd-->>wy_fast.fwd_prepare_T: chunked lower-triangular system
  wy_fast.fwd_prepare_T->>solve_tril: solve (I - tril(beta K K^T))^{-1}
  solve_tril-->>wy_fast.fwd_prepare_T: A in seq-first
  wy_fast.fwd_prepare_T->>wy_fast.fwd_prepare_T: transpose back to [B,H,T,chunk_size]
  wy_fast.fwd_prepare_T-->>Caller: A[B,H,T,chunk_size]

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

Assessment against linked issues

Objective Addressed Explanation
Define and expose fwd_prepare_T so importing fla.ops.delta_rule.parallel does not raise ImportError (#390)

Assessment against linked issues: Out-of-scope changes

(none)

Possibly related PRs

Suggested reviewers

  • yzhangcs

Poem

I thump my paw—new T is here, hooray!
Matrices march in chunked array ballet.
No ImportError nips my whiskered cheer,
Triangular tales solved crystal-clear.
From burrowed code, I hop and beam—
Fast WY flows like a carrot dream. 🥕🐇


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 797bf72 and e895d44.

📒 Files selected for processing (2)
  • fla/ops/delta_rule/wy_fast.py (1 hunks)
  • tests/ops/test_parallel_delta.py (1 hunks)

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @liqiongyu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses issue #390 by implementing the previously missing fwd_prepare_T function, which was causing an ImportError in fla/ops/delta_rule/parallel.py. The new function correctly computes the transformation matrix required for the parallel delta rule algorithm and resolves a critical tensor format mismatch by handling conversions between head-first and seq-first layouts. Comprehensive tests have been added to validate the functionality and ensure the stability of the parallel delta rule implementation.

Highlights

  • Missing Function Implementation: The fwd_prepare_T function has been added to fla/ops/delta_rule/wy_fast.py. This function is responsible for preparing the transformation matrix A for delta rule computation, specifically A = (I - tril(beta * K * K^T))^{-1}.
  • Tensor Format Handling: The newly implemented fwd_prepare_T function now correctly handles tensor format conversions, transforming input k and beta from head-first [B, H, T, K] to seq-first [B, T, H, K] for internal processing, and then converting the result back to head-first [B, H, T, chunk_size].
  • New Test File and Coverage: A new test file, tests/ops/test_parallel_delta.py, has been introduced. It includes test_parallel_delta_rule to verify the main parallel delta rule function and test_fwd_prepare_T to specifically test the newly added fwd_prepare_T function, ensuring its correct import, execution, and output shape.
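Written out, the matrix in the first highlight is (per chunk; whether the diagonal is excluded from the tril mask is an implementation detail not stated here):

```latex
T \;=\; \bigl(I - \operatorname{tril}\bigl(\operatorname{diag}(\beta)\, K K^{\top}\bigr)\bigr)^{-1},
\qquad K \in \mathbb{R}^{C \times d_k},\; \beta \in \mathbb{R}^{C},
```

where C is the chunk size and d_k the key dimension.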


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully addresses the ImportError by adding the missing fwd_prepare_T function and correctly handling the tensor format conversions. The inclusion of new tests is a great addition to ensure the fix is working as expected. I've included a few minor suggestions related to code style and formatting to improve overall code quality.


def fwd_prepare_T(
    k: torch.Tensor,
    beta: torch.Tensor,

Severity: medium

There is a trailing whitespace at the end of this line. It's good practice to remove it to maintain code cleanliness and adhere to the PEP 8 style guide [1].

Suggested change (the two lines differ only by the trailing whitespace, which the suggestion removes):
beta: torch.Tensor,

Style Guide References

Footnotes

  1. PEP 8: Avoid trailing whitespace. It can be confusing and some editors are configured to remove it automatically.

)
def test_parallel_delta_rule(
    B: int,
    H: int,

Severity: medium

There is a trailing whitespace at the end of this line. It's good practice to remove it to maintain code cleanliness and adhere to the PEP 8 style guide [1].

Suggested change (the two lines differ only by the trailing whitespace, which the suggestion removes):
H: int,

Style Guide References

Footnotes

  1. PEP 8: Avoid trailing whitespace. It can be confusing and some editors are configured to remove it automatically.

    # Check output shape
    # After our fix, fwd_prepare_T returns [B, H, T, chunk_size] (head-first format)
    expected_shape = (B, H, T, chunk_size)
    assert A.shape == expected_shape, f"Expected shape {expected_shape}, got {A.shape}"  (no newline at end of file)

Severity: medium

The file is missing a newline at the end. It's a good practice to end files with a single newline character for consistency and to avoid issues with certain tools, as recommended by PEP 8 [1].

    assert A.shape == expected_shape, f"Expected shape {expected_shape}, got {A.shape}"\n

Style Guide References

Footnotes

  1. PEP 8: All files should end in a single newline.

@liqiongyu liqiongyu closed this Aug 13, 2025
@liqiongyu liqiongyu deleted the fix-issue-390-fwd-prepare-t branch August 13, 2025 14:36
@liqiongyu liqiongyu restored the fix-issue-390-fwd-prepare-t branch August 13, 2025 14:45


Development

Successfully merging this pull request may close these issues.

[Bug] fwd_prepare_T cannot be imported

1 participant