
[WY representation] Faster lower triangle inverse #289

Merged
yzhangcs merged 14 commits into main from faster-deltanet on Apr 3, 2025

Conversation

@sustcsonglin (Collaborator) commented on Apr 3, 2025

Summary by CodeRabbit

  • New Features
    • Introduced high-performance routines for efficient scaled dot product and triangular matrix inversion operations.
  • Refactor
    • Streamlined internal processes by consolidating multiple computational steps for improved clarity and performance.
  • Tests
    • Expanded testing coverage with updated parameters and tighter precision thresholds to ensure robust functionality.
    • Added new tests for triangular matrix inversion and variable-length sequences.
    • Implemented platform-specific test execution criteria to enhance reliability.

coderabbitai bot (Contributor) commented on Apr 3, 2025

Walkthrough

This pull request introduces two new files that implement specialized Triton kernels: one for performing chunked scaled dot products and another for lower triangular matrix inversion. Additionally, new tests validate these functionalities, and an existing function in the delta rule module is refactored to utilize the new kernel functions with updated parameters.

Changes

  • fla/ops/common/chunk_scaled_dot_kkt.py, fla/ops/utils/solve_tril.py: new Triton kernel files. They implement functions for chunked scaled dot products and for computing the inverse of lower triangular matrices via specialized kernels with autotuning and boundary checks.
  • tests/ops/test_solve_tril.py, tests/ops/test_delta.py: new and modified tests. Added test suites for validating triangular matrix inversion and updated test parameters (e.g., precision thresholds, head_first parameter ordering, and variable-length sequence handling).
  • fla/ops/delta_rule/wy_fast.py: refactored fwd_prepare_wy_repr. Replaced legacy kernel calls with chunk_scaled_dot_kkt_fwd and solve_tril, and changed the default value of the head_first parameter from True to False.

Sequence Diagram(s)

sequenceDiagram
    participant Caller as Caller
    participant FPR as fwd_prepare_wy_repr
    participant CSD as chunk_scaled_dot_kkt_fwd
    participant ST as solve_tril

    Caller->>FPR: Call fwd_prepare_wy_repr(k, v, beta, offsets, indices, head_first, chunk_size)
    FPR->>CSD: Compute scaled dot product (A)
    CSD-->>FPR: Return tensor A
    FPR->>ST: Perform triangular inversion on A
    ST-->>FPR: Return inverted matrix A
    FPR-->>Caller: Return w, u, A
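The walkthrough above can be sketched with a dense NumPy reference. chunk_scaled_dot_kkt and solve_tril_ref below are hypothetical stand-ins for the PR's Triton kernels: they implement only the stated contract (a strictly lower-triangular A per chunk, and (I + A)^-1), not the actual kernel logic:

```python
import numpy as np

def chunk_scaled_dot_kkt(k, beta, BT):
    """Per chunk: A = strict_lower(diag(beta) @ k @ k^T)."""
    T = k.shape[0]
    A = np.zeros((T // BT, BT, BT))
    for c in range(T // BT):
        kc = k[c * BT:(c + 1) * BT]
        bc = beta[c * BT:(c + 1) * BT]
        A[c] = np.tril((bc[:, None] * kc) @ kc.T, k=-1)  # strictly lower triangular
    return A

def solve_tril_ref(A):
    """Dense reference for the kernel's contract: (I + A)^-1 per chunk."""
    I = np.eye(A.shape[-1])
    return np.stack([np.linalg.inv(I + Ac) for Ac in A])

T, K, BT = 64, 32, 16
rng = np.random.default_rng(0)
k = rng.standard_normal((T, K)) / np.sqrt(K)
beta = rng.random(T)
A = chunk_scaled_dot_kkt(k, beta, BT)
Ainv = solve_tril_ref(A)
for Ac, Binv in zip(A, Ainv):
    # (I + A) @ (I + A)^-1 == I must hold on every chunk
    assert np.allclose((np.eye(BT) + Ac) @ Binv, np.eye(BT))
```

Since I + A is unit lower triangular, it is always invertible, which is why the kernels never need pivoting.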

Possibly related PRs

  • [DeltaNet] WY repr speedup #279: The changes in the main PR introduce new functions for scaled dot product operations, while the retrieved PR modifies the fwd_prepare_wy_repr function to incorporate calls to these new functions, indicating a direct relationship at the code level.

Suggested reviewers

  • yzhangcs

Poem

I'm a hopping rabbit in a code-filled glen,
Munching kernels and merging functions again and again.
With chunked dots and triangles turned right,
My lines of code leap like stars in the night.
Carrots and code—what a delightful sight!


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a8aa59e and 67d86a1.

📒 Files selected for processing (2)
  • tests/ops/test_delta.py (8 hunks)
  • tests/ops/test_dplr_delta.py (3 hunks)
🔇 Additional comments (12)
tests/ops/test_dplr_delta.py (3)

12-12: Adding device_platform import looks good

This import enables platform-specific test skipping, which is important for handling platform-dependent test behaviors.


317-320: Good addition of skip condition for Intel platform

This skip condition prevents test failures on Intel platforms where Triton implementation appears to have issues. This is a better approach than having tests fail unpredictably.


409-412: Consistent skip condition for Intel platform

The same skip condition applied to the variable length test ensures consistent behavior across related test functions.

tests/ops/test_delta.py (9)

11-11: Device platform import looks good

Similar to the other test file, this import enables platform-specific test skipping.


21-21: Test sequence length adjustment looks appropriate

Removing the shortest sequence length (1) and adding a longer one (512) better focuses the tests on realistic sequence lengths relevant to the lower triangle inverse optimization.


33-33: Preference for seq-first format in tests

Reordering the test parameters to check head_first=False first suggests a preference for sequence-first format, which is consistent with modern transformer implementations.


38-41: Appropriate skip condition for Intel platforms

Consistent with the other test file, this skip condition prevents predictable failures on Intel platforms.


51-51: Good addition of random seed

Setting a fixed random seed ensures test reproducibility, which is important for debugging and consistent test results.


95-101: Adjusted tolerances for numerical stability

The increased tolerance values (from 0.005 to 0.006 or 0.008) appropriately account for potential minor numerical differences in the optimized implementation while maintaining adequate precision for testing.
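A later call in this thread shows the repo's helper invoked as assert_close("solve_tril", A_inv, A_inv_ref, 0.0001); its internal metric is not shown, so the sketch below assumes a relative L1 error, purely to illustrate what the 0.005 → 0.008 loosening changes:

```python
def rel_err(ref, tri):
    # relative L1 error between reference and candidate outputs
    num = sum(abs(r - t) for r, t in zip(ref, tri))
    den = sum(abs(r) for r in ref) + 1e-12
    return num / den

def assert_close(name, ref, tri, ratio):
    err = rel_err(ref, tri)
    assert err < ratio, f"{name}: relative error {err:.4g} exceeds {ratio}"

ref = [1.0, -2.0, 3.0]
tri = [1.012, -2.012, 3.012]  # ~0.6% relative error

passed_loose, failed_tight = True, False
try:
    assert_close("grad", ref, tri, 0.008)  # new, looser threshold: passes
except AssertionError:
    passed_loose = False
try:
    assert_close("grad", ref, tri, 0.005)  # old, tighter threshold: fails
except AssertionError:
    failed_tight = True
assert passed_loose and failed_tight
```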


114-117: Consistent skip condition for variable length tests

Consistent with other tests, this prevents running on Intel platforms where Triton failures are expected.


131-131: Improved minimum segment length

Increasing the minimum segment length to 16 (by using torch.arange(16, T)) helps prevent potential numerical instability or algorithmic edge cases with very short sequences, which is particularly relevant for matrix operations like lower triangular inversion.
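One way to guarantee the ≥ 16 minimum segment length when sampling cu_seqlens, shown as a plain-Python sketch (the test's actual sampling code is not reproduced here, so the helper name and logic are illustrative):

```python
import random

def make_cu_seqlens(T, n_segments, min_len=16, seed=0):
    """Cut [0, T) into n_segments pieces, each at least min_len long, by
    spreading the slack T - n_segments * min_len over sorted random offsets."""
    rng = random.Random(seed)
    slack = T - n_segments * min_len
    assert slack >= 0, "T too small for this many segments"
    cuts = sorted(rng.randint(0, slack) for _ in range(n_segments - 1))
    bounds = [0]
    for i, c in enumerate(cuts, start=1):
        bounds.append(i * min_len + c)  # each boundary advances by >= min_len
    bounds.append(T)
    return bounds

cu = make_cu_seqlens(256, 5)
assert cu[0] == 0 and cu[-1] == 256
assert min(b - a for a, b in zip(cu, cu[1:])) >= 16
```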


175-179: Consistent tolerance adjustments

The tolerance adjustments are applied consistently across all gradient assertions, maintaining the same level of precision requirements throughout the tests.



coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/ops/test_solve_tril.py (2)

30-33: Clarify or re-check the skip condition for chunk-varlen tests.
The skip condition is triggered when SKIP_TEST_CHUNK_VARLEN == "0", yet the reason states "TEST_CHUNK_VARLEN is enabled." This logic can be confusing. Consider ensuring the naming matches the intended behavior, such as skipping the test when the corresponding environment variable is actually enabled (truthy).


56-78: Reduce nested looping for improved test performance.
In test_solve_tril_varlen, lines 72–77 implement a nested loop over sequence chunk boundaries. While correct, it may slow down large-scale test runs. Consider vectorized or batched approaches if performance becomes a bottleneck.

fla/ops/utils/solve_tril.py (1)

25-71: Consider partial-chunk edge handling in solve_tril_16x16_kernel.
The kernel handles 16×16 blocks via a loop (line 63), assuming at least partial coverage. For smaller remaining chunks or minimal T, carefully verify that boundary conditions always hold. This can improve safety and clarity of partial-block logic.

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between e5b46b2 and 027d951.

📒 Files selected for processing (3)
  • fla/ops/common/chunk_scaled_dot_kkt.py (1 hunks)
  • fla/ops/utils/solve_tril.py (1 hunks)
  • tests/ops/test_solve_tril.py (1 hunks)
🧰 Additional context used
🧬 Code Definitions (1)
tests/ops/test_solve_tril.py (2)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
  • chunk_scaled_dot_kkt_fwd (69-99)
fla/ops/utils/solve_tril.py (1)
  • solve_tril (234-281)
🪛 GitHub Actions: nvidia-h100-ci
tests/ops/test_solve_tril.py

[error] line 43: TypeError: 'NoneType' object is not subscriptable (3 occurrences)
[error] line 45: TypeError: 'NoneType' object is not subscriptable (3 occurrences)

⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: test
🔇 Additional comments (5)
tests/ops/test_solve_tril.py (1)

34-53: ❓ Verification inconclusive

Investigate potential NoneType subscript issue in indexing.
The pipeline failures suggest a "TypeError: 'NoneType' object is not subscriptable" at lines 43 and 45. Possible causes include an unexpected zero or undefined value of T, or an uninitialized A in certain environments.

Use the following script to scan for scenarios where T might be set to zero or incorrectly inferred:

If any environment or code path sets T to 0, additional conditional guards or error handling might be required.


🏁 Script executed:

#!/bin/bash
# Check for the presence or usage of T=0 in the codebase.
rg "test_t_list" -A 3
rg -A3 "if compiled_mode"



Action: Confirm that solve_tril always returns a valid tensor before indexing its result.

After reviewing the parameters across the codebase, we see that in this test the value for T is consistently set to nonzero values (e.g., 64, 128, 200, 300, 500). This indicates that the "NoneType" issue is unlikely to stem from an unexpected zero value of T or an improperly computed A in this test. The reported error at lines 43 and 45 may instead be due to a code path inside the solve_tril function returning None.

Please verify the following:

  • Confirm that solve_tril does not return None under any circumstances.
  • Add input validation or an explicit guard in solve_tril if there’s any possibility of a missing return value.
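A minimal guard of the kind suggested, wrapping any solver so a missing return fails with a clear message instead of the downstream NoneType subscript error (the wrapper and stub names are illustrative):

```python
def checked_solve_tril(solve_tril_fn, A, **kwargs):
    out = solve_tril_fn(A, **kwargs)
    if out is None:
        raise RuntimeError(
            "solve_tril returned None; check that every code path returns the result tensor"
        )
    return out

def buggy_solver(A):
    _ = len(A)  # work happens, but the return statement is missing

caught = False
try:
    checked_solve_tril(buggy_solver, [[1.0]])
except RuntimeError:
    caught = True
assert caught
```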
fla/ops/common/chunk_scaled_dot_kkt.py (2)

69-99: Looks solid for chunk-based scaled dot-product computations.
The high-level wrapper sets up shapes, prepares indices, and calls the Triton kernel. Implementation is clear and logically consistent with the design.


22-61: ❓ Verification inconclusive

Ensure robust handling of T, K, and chunk-size corner cases.
Inside chunk_scaled_dot_kkt_fwd_kernel, indexing uses T, K, BT, BK with boundary checks. Ensure those checks protect against invalid or negative dimensions if unexpected values slip in, e.g., through input misconfiguration.

You can confirm by searching for calls to chunk_scaled_dot_kkt_fwd combined with dimension checks:


🏁 Script executed:

#!/bin/bash
rg -A 5 "chunk_scaled_dot_kkt_fwd("



Action: Verify and enhance dimension validations in chunk_scaled_dot_kkt_fwd_kernel

The review comment’s intent is still valid: ensure that the boundary checks for T, K, BT, and BK are robust against unexpected or negative input dimensions. Note that the initial regex command produced a parse error because the parenthesis wasn’t escaped. To properly search for function calls in the codebase, please use an escaped pattern such as:

rg -A 5 'chunk_scaled_dot_kkt_fwd\('

Points to address:

  • Confirm that all call sites passing values to chunk_scaled_dot_kkt_fwd_kernel enforce proper validations of inputs (especially for T, K, and chunk sizes).
  • If these values can be misconfigured, consider adding explicit error handling or assertions at the entry point of the kernel.
  • Re-run the corrected search command to verify that all calls to chunk_scaled_dot_kkt_fwd utilize appropriate dimension checks.
fla/ops/utils/solve_tril.py (2)

85-133: Merging 16×16 blocks into 32×32 blocks is well-structured.
The stepwise logic (lines 126–133) for merging blocks appears correct and well-bounded. The negative sign usage for inverting the sub-block accumulation aligns with the approach.


234-281: Implementation of solve_tril function meets stated objectives.
The function introduces chunked inverse logic for 16×16, 32×32, and 64×64 blocks, returning (I + A)^-1 as promised. Assertions enforce matrix shape, data type, and contiguity. This is consistent with the documented contract.
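The block-merge step (and the negative sign noted earlier for the 16×16 → 32×32 assembly) follows from the standard 2×2 block inverse of a lower triangular matrix; a NumPy sketch of one merge, checked against a dense inverse:

```python
import numpy as np

def merge_tril_inverses(M, B):
    """For unit lower triangular M = [[L11, 0], [L21, L22]] with B x B blocks:
    M^-1 = [[L11^-1, 0], [-L22^-1 @ L21 @ L11^-1, L22^-1]]."""
    L11inv = np.linalg.inv(M[:B, :B])
    L22inv = np.linalg.inv(M[B:, B:])
    out = np.zeros_like(M)
    out[:B, :B] = L11inv
    out[B:, B:] = L22inv
    out[B:, :B] = -L22inv @ M[B:, :B] @ L11inv  # negative sign on the off-diagonal block
    return out

B = 16
rng = np.random.default_rng(1)
M = np.eye(2 * B) + 0.1 * np.tril(rng.standard_normal((2 * B, 2 * B)), k=-1)
assert np.allclose(merge_tril_inverses(M, B), np.linalg.inv(M))
```

Applying the same merge once more (32×32 → 64×64) is presumably how the third kernel assembles the full chunk inverse.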

@zhiyuan1i (Collaborator) left a comment


sonta, I think os.getenv("SKIP_TEST_CHUNK_VARLEN") needs to be "1" here; maybe I can make it clearer.

@pytest.mark.parametrize("H", test_h_list)
@pytest.mark.parametrize("cu_seqlens", test_t_varlen_list)
@pytest.mark.parametrize("chunk_size", [64, 32, 16])
@pytest.mark.skipif(
    os.getenv("SKIP_TEST_CHUNK_VARLEN") == "1",
    reason="Skipping test because TEST_CHUNK_VARLEN is enabled"
)
def test_solve_tril_varlen(H, cu_seqlens, chunk_size):
    T = cu_seqlens[-1]

coderabbitai bot (Contributor) left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
fla/ops/utils/solve_tril.py (2)

286-311: Test code looks good, but contains a commented-out breakpoint

The test implementation correctly validates the implementation against a reference solution using PyTorch's inverse function.

However, there's a commented-out breakpoint on line 309 that should be removed before merging:

-    # breakpoint()

235-239: Enhance docstring with parameter descriptions

The docstring is clear about what the function does, but it would be helpful to add descriptions for each parameter, especially for cu_seqlens and head_first which may not be self-explanatory to users unfamiliar with the codebase.

def solve_tril(A, cu_seqlens=None, head_first=True, output_dtype=torch.float32):
    """
    Compute the inverse of the lower triangular matrix
    A should be strictly lower triangular. Please make sure A.triu() == 0.
    return: (I + A)^-1
+
+    Parameters:
+    - A: Lower triangular matrix of shape [B, H, T, BT] if head_first=True else [B, T, H, BT]
+    - cu_seqlens: Optional. Cumulative sequence lengths for variable-length sequences
+    - head_first: If True, input is [B, H, T, BT], else [B, T, H, BT]
+    - output_dtype: Data type of the output matrix
    """
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 027d951 and 0442fdc.

📒 Files selected for processing (2)
  • fla/ops/utils/solve_tril.py (1 hunks)
  • tests/ops/test_solve_tril.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/ops/test_solve_tril.py
🧰 Additional context used
🪛 GitHub Actions: lint
fla/ops/utils/solve_tril.py

[error] 271-271: flake8: E231 missing whitespace after ','

⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: test
🔇 Additional comments (3)
fla/ops/utils/solve_tril.py (3)

13-71: Matrix inversion kernel implementation looks good

The solve_tril_16x16_kernel function is well-implemented with proper autotuning configuration and boundary checks. The algorithm correctly computes the inverse of a 16x16 lower triangular matrix with appropriate memory layout handling.


73-133: Implementation of 32x32 inverse assembly kernel looks solid

The kernel effectively merges 16x16 inverse blocks into a 32x32 inverse matrix using the appropriate block matrix operations. Good use of boundary checks and type conversions.


135-232: 64x64 inverse assembly implementation is correct

The kernel correctly implements the block matrix operations required to merge 16x16 inverse blocks into a 64x64 inverse matrix. All boundary checks and type conversions are properly handled.

Comment thread on fla/ops/utils/solve_tril.py (outdated)
coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
fla/ops/utils/solve_tril.py (3)

135-145: Autotuning key includes unused parameters

The autotuning key includes parameters ('K', 'BK', 'BC') that don't appear in the function signature, which may lead to suboptimal autotuning.

 @triton.autotune(
     configs=[
         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
         for num_warps in [2, 4]
         for num_stages in [2, 3, 4]
     ],
-    key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'USE_OFFSETS'],
+    key=['H', 'HEAD_FIRST', 'USE_OFFSETS'],
 )

234-243: Consider adding assertion to validate strictly lower triangular property

The function documentation states that A should be strictly lower triangular, but there's no assertion to validate this property.

 def solve_tril(A, cu_seqlens=None, head_first=True, output_dtype=torch.float32):
     """
     Compute the inverse of the lower triangular matrix
     A should be strictly lower triangular. Please make sure A.triu() == 0.
     return: (I + A)^-1
     """
     assert A.shape[-1] in [16, 32, 64]
     assert A.dtype == torch.float32, "A should be float32."
     assert A.is_contiguous(), "A should be contiguous."
+    assert torch.all(torch.triu(A, 1) == 0), "A should be strictly lower triangular (A.triu(1) == 0)."
     if head_first is True:

309-309: Remove debugging breakpoint

There's a commented-out breakpoint statement which should be removed before the final submission.

     from fla.ops.utils.testing import assert_close

-    # breakpoint()
     assert_close("solve_tril", A_inv, A_inv_ref, 0.0001)
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 0442fdc and 566a503.

📒 Files selected for processing (1)
  • fla/ops/utils/solve_tril.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: test
🔇 Additional comments (5)
fla/ops/utils/solve_tril.py (5)

1-11: Clean and appropriate imports

The imports are appropriately organized and include all necessary libraries for implementing Triton kernels for matrix operations.


13-35: Well-configured kernel with proper autotuning

The solve_tril_16x16_kernel has appropriate Triton decorators:

  • Effective heuristics for conditional execution
  • Good autotuning configuration with varying warps and stages
  • Clear parameter definitions with proper type annotations

36-71: Efficient implementation of 16x16 inverse calculation

The kernel is well-implemented with:

  • Proper handling of program IDs for parallel execution
  • Effective management of variable batch sizes and dimensions
  • Correct use of Triton's block pointer functionality
  • Appropriate boundary checking
  • Numerically sound algorithm for inverting the lower triangular matrix

271-271: Corrected whitespace issue

The whitespace after the comma has been fixed, addressing the issue flagged in the previous review.


286-311: Comprehensive test case

The test logic is thorough and provides good validation:

  • Creates realistic test data
  • Handles padding and reshaping appropriately
  • Computes a reference solution using torch.inverse
  • Compares results with suitable tolerance

coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
fla/ops/common/chunk_scaled_dot_kkt.py (4)

24-38: Consider adding more documentation or type hints.
While Triton constraints are often implicit, inline explanations of each parameter would improve clarity for future maintainers.


48-53: Boundary checks.
Only the first dimension is in boundary_check=(0,). Confirm if the second dimension also needs boundary enforcement when data is non-contiguous.


72-92: Chunk size tunability.
chunk_size=64 is a practical default, but exploring other chunk sizes might yield better performance for different sequence lengths.


93-105: Automatic detection for USE_OFFSETS.
The heuristic-based approach is convenient but can be non-obvious to new contributors. Adding a short comment or docstring would help clarify.
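For readers unfamiliar with the pattern, @triton.heuristics resolves constexpr flags from the call's arguments at launch time. A plain-Python sketch of how a USE_OFFSETS heuristic (assumed here to key off the offsets argument) would evaluate:

```python
# each heuristic maps a meta-parameter name to a function of the kernel's args
heuristics = {"USE_OFFSETS": lambda args: args["offsets"] is not None}

def resolve_meta(args):
    return {name: fn(args) for name, fn in heuristics.items()}

assert resolve_meta({"offsets": None}) == {"USE_OFFSETS": False}
assert resolve_meta({"offsets": [0, 64, 128]}) == {"USE_OFFSETS": True}
```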

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 912c740 and 389a0ad.

📒 Files selected for processing (1)
  • fla/ops/common/chunk_scaled_dot_kkt.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (6)
fla/ops/common/chunk_scaled_dot_kkt.py (6)

1-3: File header looks good.
No issues spotted in the initial comments and license line.


4-11: Imports are appropriate.
No conflicts or missing imports identified for these lines.


13-23: Autotune and heuristics configuration.
The configs for num_warps and num_stages seem reasonable. Ensure performance is tested thoroughly across different data shapes.


39-47: Edge case handling for offsets.
If eos < bos ever occurs, the effective T would be negative, causing unexpected behavior. Consider an assertion or safe check.


54-63: Loop iteration bounds.
For very large K, ensure the range in for i_k in range(tl.cdiv(K, BK)) doesn't risk integer overflow or overly large loops.


64-69: Lower-triangular masking logic.
Masking out the upper triangle is correct for KKT or triangular-related computations.

coderabbitai bot (Contributor) left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
fla/ops/common/chunk_scaled_dot_kkt.py (2)

54-63: Potential performance consideration when multiplying large chunks.
Inside the loop, you compute b_kb = b_k * b_beta[:, None] and then do tl.dot(...). For very large K, ensure that the partial sum fits in floating-point precision. If necessary, consider employing accumulation in a higher-precision type (e.g., FP32 or FP64) before storing the final result, to minimize numerical errors.
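The precision concern can be felt without any GPU: emulating an FP32 accumulator in plain Python (rounding through struct after every add) versus Python's native float64 shows why wider accumulation before the final store matters:

```python
import struct

def to_f32(x):
    # round a Python float (binary64) to the nearest binary32
    return struct.unpack("f", struct.pack("f", x))[0]

vals = [0.01] * 100_000        # exact sum: 1000.0
acc32 = 0.0
for v in vals:
    acc32 = to_f32(acc32 + to_f32(v))  # f32 rounding after every addition
acc64 = sum(vals)                       # f64 accumulation

err32 = abs(acc32 - 1000.0)
err64 = abs(acc64 - 1000.0)
assert err32 > err64  # the wider accumulator tracks the partial sum far more accurately
```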


72-82: Docstring can elaborate on the lower-triangular constraint and offsets.
The docstring states "Compute beta * k * k^T," but the result is further processed to yield a strictly lower-triangular matrix. Consider mentioning this triangular masking step, and also how the optional cu_seqlens offsets play into chunked variable-length operations.

- r"""
- Compute beta * k * k^T
- """
+ r"""
+ Compute beta * k * k^T and then masks upper triangle to produce a strictly lower-triangular result.
+ If cu_seqlens is provided, the kernel uses chunk offsets for variable-length sequences.
+ """
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 389a0ad and c2674c1.

📒 Files selected for processing (1)
  • fla/ops/common/chunk_scaled_dot_kkt.py (1 hunks)
🔇 Additional comments (5)
fla/ops/common/chunk_scaled_dot_kkt.py (5)

1-3: File header looks fine.
The UTF-8 declaration, copyright line, and spacing are properly set.


13-24: Triton decorators are well-organized.
The usage of @triton.heuristics and @triton.autotune for controlling runtime parameters looks correct. Specifying key configurations with num_warps and num_stages is consistent with Triton best practices to tune performance.


48-53: Verify correct pointer arithmetic for beta in HEAD_FIRST vs. non-HEAD_FIRST cases.
While the indexing logic for p_beta appears consistent, it could be helpful to document these pointer layouts more explicitly in the docstring or surrounding comments, to clarify why the stride changes between (T,) vs. (H,).

Would you like a script to scan for all instances of beta usage across the repository to confirm consistent indexing?


64-69: Double-check the strictly lower-triangular masking.
Line 64 enforces all diagonal and upper-triangular elements to 0:

b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)

This zeroes out the diagonal as well. Confirm that this behavior (strictly lower-triangular) is intended, or whether the diagonal should be preserved.
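To make the masking semantics concrete, here is the same predicate in plain Python. The diagonal is indeed zeroed, which matches solve_tril's documented contract: A is strictly lower triangular and the identity is added back when computing (I + A)^-1.

```python
BT = 4
# mirror of: tl.where(arange[:, None] > arange[None, :], b_A, 0)
b_A = [[1.0] * BT for _ in range(BT)]
masked = [[b_A[i][j] if i > j else 0.0 for j in range(BT)] for i in range(BT)]

assert all(masked[i][i] == 0.0 for i in range(BT))                        # diagonal zeroed
assert all(masked[i][j] == 0.0 for i in range(BT) for j in range(i, BT))  # upper + diag zeroed
assert all(masked[i][j] == 1.0 for i in range(BT) for j in range(i))      # strict lower kept
```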


92-105: Ensure all shape corner cases are covered.
The kernel launch uses:

chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](...)

with NT = triton.cdiv(T, BT) if cu_seqlens is None. Verify that extremely small or zero-sized T does not cause off-by-one errors or boundary issues, particularly when the chunk size is greater than T.
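The corner cases reduce to the behavior of ceiling division; triton.cdiv is equivalent to the one-liner below, and T = 0 is the case that silently launches nothing:

```python
def cdiv(a, b):
    # ceiling division, equivalent to triton.cdiv
    return -(-a // b)

# NT for the (NT, B * H) launch grid with BT = 64
assert cdiv(64, 64) == 1  # exactly one chunk
assert cdiv(65, 64) == 2  # trailing partial chunk gets its own program
assert cdiv(1, 64) == 1   # T smaller than the chunk size still launches one program
assert cdiv(0, 64) == 0   # zero-length T launches no programs; worth an explicit guard upstream
```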

Comment on lines +39 to +47
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
    i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
    bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
    T = eos - bos
else:
    bos, eos = i_b * T, i_b * T + T


🛠️ Refactor suggestion

Consider renaming i_t to avoid overshadowing.
Here, i_t is read from tl.program_id(0), but then overwritten via the offsets logic:

i_n, i_t = tl.load(indices + i_t * 2), tl.load(indices + i_t * 2 + 1)

This can create confusion and obscure the kernel’s indexing logic. A clearer naming convention for the chunk index (e.g., chunk_id) would improve maintainability.

Below is a sample diff showing how to rename the overwritten variable:

- i_t, i_bh = tl.program_id(0), tl.program_id(1)
- if USE_OFFSETS:
-     i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
+ chunk_id, i_bh = tl.program_id(0), tl.program_id(1)
+ if USE_OFFSETS:
+     i_n, local_t = tl.load(indices + chunk_id * 2).to(tl.int32), tl.load(indices + chunk_id * 2 + 1).to(tl.int32)

coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (5)
tests/ops/test_delta.py (1)

166-170: Ensure consistency in error threshold adjustments.
Similar to the earlier block, increasing the thresholds for these asserts can hide small, potentially important deviations. Consider verifying the root cause of any test failures that prompted this change.

If necessary, refine kernels or intermediate calculations to improve numerical precision, rather than only loosening the test checks.

fla/ops/common/chunk_scaled_dot_kkt.py (2)

62-62: Assess potential mixed dtype in the dot product.
The code does b_kb.to(b_k.dtype) prior to tl.dot, which might cause unintended truncations. Consider verifying consistent dtypes or employing higher precision for intermediate sums if numeric stability is a priority.


104-127: Evaluate kernel tiling constants and performance.
The autotuning parameters (num_warps=[2,4,8], num_stages=[2,3,4], BK=64) look reasonable, but large K or T may still benefit from further kernel optimizations or dynamic block sizing.

Do you want to explore a larger range of warp counts or advanced tile shapes for better GPU occupancy?

fla/ops/utils/solve_tril.py (2)

255-274: Consider adding explicit verification of the strictly lower triangular property.

While the docstring states "A should be strictly lower triangular," there's no explicit verification of this requirement in the implementation, which could lead to incorrect results if a non-conforming matrix is provided.

Consider adding a verification step before processing:

def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: torch.dtype = torch.float
) -> torch.Tensor:
    """
    Compute the inverse of the lower triangular matrix
    A should be strictly lower triangular, i.e., A.triu() == 0.
    ...
    """
    assert A.shape[-1] in [16, 32, 64]
    assert A.dtype == torch.float, "A should be float32."
+   # Verify A is strictly lower triangular
+   # torch.triu acts on the last two dims, so a single layout-independent
+   # check suffices for both head_first settings
+   is_tril = torch.all(torch.triu(A, diagonal=1) == 0)
+   assert is_tril, "Input matrix must be strictly lower triangular (A.triu(diagonal=1) == 0)"

275-276: Consider relaxing dtype requirements for greater flexibility.

Currently, the implementation strictly requires float32 input. Consider supporting other precision types (float16, bfloat16) that might be useful in different contexts, particularly for large models or memory-constrained environments.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c2674c1 and d55608e.

📒 Files selected for processing (4)
  • fla/ops/common/chunk_scaled_dot_kkt.py (1 hunks)
  • fla/ops/delta_rule/wy_fast.py (2 hunks)
  • fla/ops/utils/solve_tril.py (1 hunks)
  • tests/ops/test_delta.py (3 hunks)
🧰 Additional context used
🧬 Code Definitions (1)
fla/ops/delta_rule/wy_fast.py (2)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
  • chunk_scaled_dot_kkt_fwd (72-126)
fla/ops/utils/solve_tril.py (1)
  • solve_tril (249-318)
🔇 Additional comments (14)
tests/ops/test_delta.py (2)

21-21: Acknowledge expanded test coverage with larger sequence length.
By including 512 in the test_t_list, the tests now cover higher sequence lengths, which helps identify performance issues and edge cases. Ensure that necessary GPU memory is available and that test runtime remains acceptable for large T.


90-96: Relaxed error thresholds may mask subtle differences.
Raising the thresholds from 0.005/0.007 to 0.006/0.008 reduces the strictness of the test. Confirm the new thresholds are appropriate and won't overlook potential numerical instability in the kernel.

Would you like to double-check numerical stability across different devices or random seeds to ensure these relaxed thresholds remain reliable?
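One lightweight way to quantify this (a hypothetical helper sketch in plain Python, not necessarily the comparison the test suite actually uses): compute a relative L1 error ratio against the reference output and compare it to the threshold across several seeds.

```python
def error_ratio(ref, out):
    """Relative L1 error: sum |ref - out| / sum |ref| (eps guards zero refs)."""
    num = sum(abs(r - o) for r, o in zip(ref, out))
    den = sum(abs(r) for r in ref) + 1e-12
    return num / den

# A threshold of 0.008 means the total deviation may reach 0.8% of the
# reference magnitude before the test fails; re-running over multiple seeds
# shows whether typical ratios sit close to that bound.
```

If observed ratios cluster just under the relaxed thresholds, that would suggest a genuine precision issue in the kernel rather than harmless noise.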

fla/ops/delta_rule/wy_fast.py (4)

10-11: Good integration with chunked K-K^T and TRIL solver.
These new imports improve code organization by extracting specialized logic into separate modules.


223-230: Validate correct usage of chunk_scaled_dot_kkt_fwd.
The function call is consistent with the docstring, but double-check that calling code supplies correct shapes and ensures data is contiguous if required.

Do you want to run an additional script to confirm the shapes of k and beta across the codebase?


231-236: Verify return dtype alignment in solve_tril.
Returning A in k.dtype after a TRIL solve might introduce subtle precision differences if k.dtype is lower than float32.

Confirm that k is typically float32 or a higher-precision dtype. Alternatively, keep the returned matrix in float32 for stability.


220-220:

Details

❓ Verification inconclusive

Changing head_first default to False.
Previously, the default was True. This can alter the shape assumptions in downstream code. Verify that all external usage correctly accounts for the new default.


🏁 Script executed:

#!/bin/bash
# Search codebase for calls to `fwd_prepare_wy_repr` that rely on the old default.
rg -A 3 "fwd_prepare_wy_repr\(.*\)"

Length of output: 37


Verify Impact of Default Change on head_first

The default value for head_first in fwd_prepare_wy_repr has been changed from True to False in fla/ops/delta_rule/wy_fast.py at line 220. Our automated search for calls to fwd_prepare_wy_repr that might rely on the old default did not return any instances. However, this outcome is inconclusive. Please manually verify that all external usages either explicitly set the head_first parameter or operate correctly under the new default value.

  • File to check: fla/ops/delta_rule/wy_fast.py (line 220)
  • Action Required: Confirm that downstream code and external calls to fwd_prepare_wy_repr correctly handle the change from head_first=True to head_first=False.
fla/ops/common/chunk_scaled_dot_kkt.py (1)

39-47: Variable name shadowing remains.
Reusing i_t to denote both tl.program_id(0) and the chunk offset may cause confusion.

fla/ops/utils/solve_tril.py (7)

14-71: Well-optimized kernel implementation with clear structure.

The kernel correctly implements an efficient algorithm for 16x16 lower triangular matrix inversion with proper boundary checks and memory access patterns. The autotuning configuration provides good performance optimization opportunities.


73-132: Effective 32x32 matrix inversion using block matrix approach.

The implementation correctly uses block matrix inversion formulas to merge 16x16 inversions into a 32x32 inverse. The boundary checks and precision handling are properly implemented.
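For reference, the block identity the merge kernel relies on can be sanity-checked in plain Python (an illustrative sketch with scalar "blocks", not the Triton kernel): for a lower block-triangular M = [[A, 0], [C, B]], the inverse is [[A^-1, 0], [-B^-1 C A^-1, B^-1]].

```python
# Illustrative sketch (plain Python, not the Triton kernel): the block
# inversion identity used to merge two 16x16 inverses, shown with scalar
# blocks a and b on the diagonal and c below:
#   M = [[a, 0], [c, b]]  =>  M^-1 = [[1/a, 0], [-(1/b)*c*(1/a), 1/b]]
def matmul(x, y):
    rows, inner, cols = len(x), len(y), len(y[0])
    return [[sum(x[i][k] * y[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]

def block_inverse(a, b, c):
    """Invert [[a, 0], [c, b]] for nonzero scalar blocks a and b."""
    ia, ib = 1.0 / a, 1.0 / b
    return [[ia, 0.0], [-ib * c * ia, ib]]

M = [[2.0, 0.0], [3.0, 4.0]]
Minv = block_inverse(2.0, 4.0, 3.0)
identity = matmul(M, Minv)  # -> [[1.0, 0.0], [0.0, 1.0]]
```

The 64x64 kernel applies the same identity recursively, with 16x16 matrices in place of the scalars.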


134-246: Comprehensive 64x64 matrix inversion with proper block structure.

The implementation effectively handles the complexity of a 64x64 matrix inversion by decomposing it into appropriate subblocks. The matrix multiplication operations and boundary checks are properly sequenced to ensure correct results.


285-296: LGTM: Correct kernel selection and parameterization.

The function correctly prepares chunk indices and calls the solve_tril_16x16_kernel with appropriate parameters for handling variable-length sequences.


304-317: LGTM: Proper kernel selection based on matrix dimensions.

The code correctly selects the appropriate merge function based on the matrix dimensions (32x32 or 64x64) and passes all the necessary parameters.


285-286: Nice use of triton.cdiv for calculating grid dimensions.

The calculation of NT properly handles both variable-length sequences (using indices length) and fixed-length sequences (using triton.cdiv), ensuring efficient kernel grid configuration.
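Spelled out as a trivial sketch (not Triton's implementation), triton.cdiv is ceiling division, so the grid gets one program per 16-row chunk, including a partial trailing chunk.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, matching the semantics of triton.cdiv."""
    return (a + b - 1) // b

# One program per 16-row chunk; a partial trailing chunk still gets one.
print(cdiv(64, 16), cdiv(65, 16), cdiv(1, 16))  # -> 4 5 1
```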


248-318: Excellent orchestration function for different matrix sizes.

The solve_tril function provides a clean interface that encapsulates the complexity of the matrix inversion process. It correctly handles different matrix sizes (16x16, 32x32, 64x64) and memory layouts (head-first or not).

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
fla/ops/utils/solve_tril.py (1)

134-246: Comprehensive 64x64 matrix inversion with block decomposition

The implementation correctly extends the block inversion approach to 64x64 matrices, handling the more complex interdependencies between blocks. The use of IEEE precision and proper boundary checks is commendable.

Consider adding comments explaining the mathematical relationships between blocks to improve maintainability, as the block operations are quite dense and complex.

fla/ops/common/chunk_scaled_dot_kkt.py (1)

66-66: Consider optimizing out the upper-triangle computation.

Currently, the entire b_A block is computed and the upper triangular part is then zeroed out. If only the lower triangle is of interest, compute it directly to avoid the wasted work.
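A plain-Python sketch of the suggested change (illustrative only, with hypothetical names; the real kernel operates on tiled Triton blocks): computing the strictly lower entries of K @ K^T directly yields the same result as computing everything and masking afterwards.

```python
def dot_kkt_masked(k):
    """Full K @ K^T followed by masking, analogous to the current kernel."""
    n, d = len(k), len(k[0])
    a = [[sum(k[i][t] * k[j][t] for t in range(d)) for j in range(n)]
         for i in range(n)]
    # Zero out the diagonal and upper triangle after the fact.
    return [[a[i][j] if j < i else 0.0 for j in range(n)] for i in range(n)]

def dot_kkt_lower_only(k):
    """Skip the upper-triangle dot products entirely."""
    n, d = len(k), len(k[0])
    return [[sum(k[i][t] * k[j][t] for t in range(d)) if j < i else 0.0
             for j in range(n)] for i in range(n)]
```

In a Triton kernel the analogous saving would come from restricting which tiles are computed, which matters most when the block size is large relative to the lower-triangular region.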

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d55608e and b00e556.

📒 Files selected for processing (2)
  • fla/ops/common/chunk_scaled_dot_kkt.py (1 hunks)
  • fla/ops/utils/solve_tril.py (1 hunks)
🔇 Additional comments (7)
fla/ops/utils/solve_tril.py (5)

14-71: Well-implemented Triton kernel for 16x16 matrix inversion

The implementation of solve_tril_16x16_kernel is robust, with proper boundary checks and memory layout handling. The kernel efficiently computes the inverse of a strictly lower triangular matrix, building the inverse iteratively and adding identity at the end.
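As a plain-Python reference for "building the inverse iteratively and adding identity at the end" (a sketch of the math, not the kernel's actual schedule): with A strictly lower triangular, I + A is unit lower triangular, and its inverse can be built row by row via forward substitution.

```python
def invert_unit_lower(a):
    """Invert I + A, where `a` (list of lists) is strictly lower triangular."""
    n = len(a)
    # Start from the identity; only entries below the diagonal need filling.
    inv = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for i in range(n):
        for j in range(i):
            # Row i of the inverse depends only on already-computed rows.
            inv[i][j] = -sum(a[i][k] * inv[k][j] for k in range(j, i))
    return inv

A = [[0.0, 0.0, 0.0],
     [2.0, 0.0, 0.0],
     [1.0, 3.0, 0.0]]
X = invert_unit_lower(A)  # -> [[1, 0, 0], [-2, 1, 0], [5, -3, 1]]
```

A reference like this is also handy for validating the kernel's output in tests, since the result is exact for small integer-valued inputs.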


73-132: Effective implementation of 32x32 matrix inversion via block merging

This kernel correctly implements the block matrix inversion technique by combining 16x16 inverses into a 32x32 inverse. The use of tl.dot with explicit input_precision='ieee' ensures numerical stability, which is crucial for matrix inversions.


248-318: Well-orchestrated function with clear interface

The solve_tril function provides a clean interface to the underlying kernels with proper input validation, output type handling, and support for different layouts and variable sequence lengths.


286-286: Minor style issue with spacing

There's a missing space after the comma in the triton.cdiv(T,16) call.

-    NT = len(indices) if cu_seqlens is not None else triton.cdiv(T,16)
+    NT = len(indices) if cu_seqlens is not None else triton.cdiv(T, 16)

275-276: Clear input validation with assertions

The assertions effectively validate that the input matrix has a supported size (16, 32, or 64) and the correct data type (float32). This prevents unexpected behavior and provides clear error messages.

fla/ops/common/chunk_scaled_dot_kkt.py (2)

40-43: Avoid shadowing i_t.

Reassigning i_t after reading it from tl.program_id(0) can cause confusion and reduces the clarity of the kernel’s indexing logic. Renaming the first usage (or the second) improves maintainability and aligns with prior review feedback.


63-64: Double-check precision when casting b_kb.

Casting b_beta to b_k.dtype before the dot product might introduce floating-point precision loss if b_k is a lower-precision format than beta. Ensure the desired numeric accuracy is maintained, potentially by upcasting both to a higher precision (e.g., float32) before the multiplication.
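A stdlib illustration of the concern (not the kernel code): rounding a value through IEEE-754 binary32, as a cast to a lower precision would, already perturbs common constants, and a cast to float16 loses correspondingly more.

```python
import struct

def round_to_f32(x: float) -> float:
    """Round a Python float (binary64) through binary32, emulating a downcast."""
    return struct.unpack('f', struct.pack('f', x))[0]

# 0.5 is exactly representable in binary32; 0.1 is not, so the downcast
# perturbs it before any multiplication happens.
print(round_to_f32(0.5) == 0.5)   # -> True
print(round_to_f32(0.1) == 0.1)   # -> False
```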

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
fla/ops/utils/solve_tril.py (1)

277-278: Consider supporting multiple precision types beyond float32.

The function currently restricts input to float32 only. Consider supporting other precision types (like float16) for better flexibility, especially in performance-critical applications where reduced precision might be acceptable.

-assert A.dtype == torch.float, "A should be float32."
+assert A.dtype in [torch.float, torch.float16], f"A should be float32 or float16, got {A.dtype}."
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b00e556 and b08bb88.

📒 Files selected for processing (2)
  • fla/ops/utils/solve_tril.py (1 hunks)
  • tests/ops/test_delta.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/ops/test_delta.py
⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (7)
fla/ops/utils/solve_tril.py (7)

14-71: Well-implemented kernel for inverting 16x16 lower triangular matrices with thorough optimizations.

The kernel implementation effectively uses Triton's autotuning capabilities to optimize performance across different workloads. The code properly handles boundary checking, memory layouts, and different data types. The mathematical approach for matrix inversion is sound.


73-133: Efficient implementation of 32x32 lower triangular matrix inversion using block-wise approach.

This kernel effectively merges the results from 16x16 inversions to compute a 32x32 inverse. The implementation follows standard mathematical techniques for block matrix inversion, with proper boundary checking and memory handling.


135-248: Comprehensive implementation of 64x64 lower triangular matrix inversion with proper mathematical approach.

The kernel handles the complex case of merging multiple 16x16 results into a 64x64 inverse. Despite the complexity, the implementation is mathematically sound and includes proper memory handling and boundary checks.


257-276: Clear and informative docstring.

The docstring clearly explains the function's purpose, parameters, and return values. It provides necessary information for users to understand how to properly use the function.


288-288: Use consistent spacing around operators.

The spacing in this line is inconsistent with the style used in the rest of the file.

-NT = len(indices) if cu_seqlens is not None else triton.cdiv(T,16)
+NT = len(indices) if cu_seqlens is not None else triton.cdiv(T, 16)

250-321: Well-designed orchestration function with good error handling.

The solve_tril function effectively orchestrates the appropriate kernels based on matrix size and handles different memory layouts and sequence lengths correctly. The error handling through assertions helps catch potential issues early.


1-322:

Details

❓ Verification inconclusive

Verify numerical stability with test cases.

While the implementation is mathematically sound, it's crucial to verify numerical stability across different matrix sizes and conditions, especially for edge cases.


🏁 Script executed:

#!/bin/bash
# Check if there are test cases for this new implementation

# Look for test files related to solve_tril
rg -l "solve_tril" --type py tests/

Length of output: 64


Ensure Numerical Stability Through Comprehensive Test Coverage

The presence of tests/ops/test_solve_tril.py is a good sign. Please double-check that this file covers a diverse set of scenarios, including various matrix sizes (i.e. for BT values of 16, 32, and 64) and potential edge cases (e.g. minimal sizes or atypical input values) to confirm numerical stability across all conditions.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
.github/workflows/nvidia-h100.yml (2)

50-50: Verify the MAX_JOBS configuration for flash-attn build

The command now installs flash-attn with MAX_JOBS=4 and the --no-build-isolation flag. This setting appears intended to control parallel build jobs on CI runners. Please verify that 4 is optimal for the available CI resources or consider parameterizing this value if needed in different environments.


51-51: Simplify triton installation command

The updated command pip install -U triton removes the previously specified index URL and trusted host parameters. This simplification is acceptable provided that the default PyPI repository serves the appropriate Triton builds. It’s worthwhile to double-check that this change aligns with expected package sourcing requirements.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b08bb88 and 527d742.

📒 Files selected for processing (2)
  • .github/workflows/nvidia-4090.yml (1 hunks)
  • .github/workflows/nvidia-h100.yml (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • .github/workflows/nvidia-4090.yml
⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (1)
.github/workflows/nvidia-h100.yml (1)

47-48: Streamlined pip installation commands for core packages

The change to use the short -U flag (instead of --upgrade) for installing packages like pytest, setuptools, wheel, ninja, and torch simplifies the command while preserving the intended functionality and ensuring that the latest versions are installed. This improvement increases readability and maintains consistency with similar workflows.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/ops/test_dplr_delta.py (2)

317-320: Platform-specific test skip is well-implemented.

This skip condition appropriately excludes the test on Intel platforms where Triton kernel execution is problematic. The reason message clearly explains why the test is being skipped.

Consider adding a link to an issue tracker or more details about the specific failure mode to help future maintainers understand the limitation better.


409-412: Consistent platform-specific test skipping.

This skip condition correctly matches the one applied to the test_chunk function, ensuring consistent treatment of Intel platforms across similar test cases.

To reduce duplication, consider creating a reusable marker or helper for Intel platform skipping if more tests need similar treatment in the future:

intel_skip = pytest.mark.skipif(
    device_platform == 'intel',
    reason="Intel Triton Failure"
)

# Then use it like:
@intel_skip
def test_chunk(...):
    ...
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1446e9e and a8aa59e.

📒 Files selected for processing (1)
  • tests/ops/test_dplr_delta.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (1)
tests/ops/test_dplr_delta.py (1)

12-12: Import addition for platform-specific test handling.

The import of device_platform is correctly added to support the new platform-specific test skip conditions later in the file.

@yzhangcs yzhangcs merged commit 429920a into main Apr 3, 2025
3 of 6 checks passed
@yzhangcs yzhangcs deleted the faster-deltanet branch April 3, 2025 17:06
@coderabbitai coderabbitai bot mentioned this pull request Jan 27, 2026

3 participants