[KDA] Speed up chunk_kda by introducing lowerbound gate#703

Merged
zhiyuan1i merged 10 commits into main from msh/lowerbound-kda on Dec 30, 2025

Conversation

Collaborator

zhiyuan1i commented Dec 30, 2025

After profiling, we found that the original chunk_kda kernel is heavily bound by CUDA-core computation. To maintain numerical stability along the diagonal of the (causal-mask) attention matrix we had to split each 64-element chunk a second time—running a 16-element tile on Tensor Cores and looping over the remainder with CUDA cores. This introduced enormous overhead: in Triton those loops are hard to optimize, almost impossible to overlap, and, given the fixed FLOP/byte ratio of the hardware, inevitably slower.

By switching the activation function we enabled pure Tensor-Core execution with M = 16, which delivered a large end-to-end speed-up. To regain extra precision—and to keep the option of M = 32 Tensor Cores—we added a median shift to the log-domain cumulative-sum, reducing the accuracy loss at M = 16 while still allowing M = 32. Further Triton tweaks (e.g., transposing with M = 64 Tensor Cores) have so far yielded only marginal gains. At this point chunk_kda is a purely memory-bound kernel.
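The median-shift idea can be illustrated outside Triton. The sketch below is a hypothetical PyTorch rendering of the trick (the function name and exact placement of the shift are assumptions, not the kernel code): shifting the log-domain cumulative sum by a per-chunk reference before `exp2` keeps the exponent arguments centered near zero, which is what limits the accuracy loss at M = 16.

```python
import torch

def shifted_chunk_cumsum(g: torch.Tensor):
    """Illustrative sketch (not the fla kernel): base-2 log-domain gate
    cumsum with a median shift, so exp2 sees well-conditioned arguments."""
    gc = torch.cumsum(g, dim=-1)       # log-domain cumulative gate per chunk
    ref = gc.median()                  # median shift keeps exponents small
    return torch.exp2(gc - ref), ref   # caller folds `ref` back in where needed
```

Because `exp2(gc) == exp2(gc - ref) * exp2(ref)`, the shift is exact in real arithmetic; its benefit is purely numerical, keeping the low-precision `exp2` inputs small at the M = 16/32 tile sizes.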

Summary by CodeRabbit

  • New Features

    • Added runtime controls for gating and recompute (safe_gate, disable_recompute, lower_bound, cu_seqlens) and exposed intermediate outputs to improve memory/recompute handling.
  • Bug Fixes

    • Improved numeric stability and masking for variable-length sequences; standardized explicit float32 casts and TF32 precision handling; added explicit kernel barrier for correctness.
  • Tests

    • Expanded tests to cover new gating/recompute options and kernel vs. naive gating paths.
  • Chores

    • Benchmarks updated to exercise the new gating option.


Contributor

coderabbitai bot commented Dec 30, 2025

Caution: Review failed. The pull request is closed.

Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a lower-bound-aware chunked gating path and safe_gate flag, threads safe_gate/lower_bound/disable_recompute through forward/backward, expands saved intermediates (w, u, qg, kg, v_new, h), introduces an intra sub-chunk kernel with gather/varlen branches, and updates recompute/WY wiring and tests.

Changes

| Cohort | File(s) | Summary |
|---|---|---|
| Core KDA chunk API | `fla/ops/kda/chunk.py` | Extended public APIs (`chunk_kda`, `chunk_kda_fwd`, `chunk_kda_bwd`, `ChunkKDAFunction`) to accept `cu_seqlens_cpu`, `safe_gate`, `lower_bound`, `disable_recompute`; use `prepare_chunk_indices`; switch gate import to `kda_gate_chunk_cumsum`; save/propagate intermediates (`w`, `u`, `qg`, `kg`, `v_new`, `h`); updated forward/backward wiring and memory handling. |
| Intra kernel & token-parallel | `fla/ops/kda/chunk_intra.py`, `fla/ops/kda/chunk_intra_token_parallel.py` | Added `chunk_kda_fwd_kernel_intra_sub_chunk`; introduced `USE_SAFE_GATE`/`SAFE_GATE` branches, gather-aware and varlen branches (`IS_GATHER_SUPPORTED`, `IS_VARLEN`); `chunk_kda_fwd_intra` now accepts `safe_gate`/`disable_recompute` and returns `(w, u, qg, kg, Aqk, Akk)`; TF32 precision flag adjusted. |
| WY recompute & kernels | `fla/ops/kda/wy_fast.py` | Renamed kernels to KDA-specific variants; `recompute_w_u_fwd` now returns `(w, u, qg, kg)` (4-tuple); `DOT_PRECISION` autotune options extended to include `"tf32"`; call sites updated. |
| Backward / kernel precision tweaks | `fla/ops/kda/chunk_bwd.py`, `fla/ops/common/chunk_delta_h.py`, `fla/ops/gla/chunk.py` | Added explicit `.to(tl.float32)` casts for several Triton loads and inserted an explicit `tl.debug_barrier()` comment in a fused kernel. No public API changes. |
| Gate util | `fla/ops/kda/gate.py` | Minor header comment edit; gating usage switched to `kda_gate_chunk_cumsum` and `lower_bound` is threaded through gating paths. |
| Tests & benchmarks | `tests/ops/test_kda.py`, `benchmarks/ops/benchmark_kda.py` | Tests updated to parameterize `safe_gate`/`disable_recompute` (and `use_gate_in_kernel` for varlen); kernel vs. naive gating paths apply `lower_bound`; benchmark calls `chunk_kda(..., safe_gate=True)`. |
| Utilities import | `fla/ops/utils/*` | `prepare_chunk_indices` imported/used for chunk index computation when `cu_seqlens` is provided. |

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant ChunkFunc as ChunkKDAFunction
    participant Indexer as prepare_chunk_indices
    participant Gate as kda_gate_chunk_cumsum
    participant Intra as chunk_kda_fwd_intra
    participant WY as recompute_w_u_fwd_kda
    participant Mem as MemoryStore

    Client->>ChunkFunc: forward(q,k,v,..., safe_gate, lower_bound, disable_recompute)
    alt cu_seqlens provided
        ChunkFunc->>Indexer: prepare_chunk_indices(cu_seqlens)
        Indexer-->>ChunkFunc: chunk_indices
    end
    alt safe_gate enabled
        ChunkFunc->>Gate: compute gated A/g (lower_bound, scale)
        Gate-->>ChunkFunc: gate outputs (A_log, g, dt_bias, ...)
    end
    ChunkFunc->>Intra: chunk_kda_fwd_intra(..., safe_gate, disable_recompute, chunk_indices)
    Intra-->>ChunkFunc: (w,u,qg,kg,Aqk,Akk,v_new,h)
    alt recompute allowed
        ChunkFunc->>WY: recompute_w_u_fwd_kda(...)
        WY-->>ChunkFunc: (w,u,qg,kg)
    else
        ChunkFunc->>Mem: preserve intermediates for backward
    end
    ChunkFunc->>Mem: save state (w,u,qg,kg,v_new,h,safe_gate,lower_bound,disable_recompute)
    Client->>ChunkFunc: backward(grad_out)
    ChunkFunc->>ChunkFunc: restore or recompute intermediates
    alt safe_gate enabled
        ChunkFunc->>Gate: backward through gate (uses saved/recomputed intermediates)
    end
    ChunkFunc-->>Client: gradients (dq,dk,dv,dA_log,ddt_bias)
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • PR #643: Modifies intra-kernel implementation in fla/ops/kda/chunk_intra.py — overlaps with intra kernel and sub-chunk changes.
  • PR #701: Implements lower-bound-aware chunked gating and safe_gate flags wired through chunk forward/backward — directly related to gating changes here.
  • PR #689: Updates KDA chunk backward flow and intermediate wiring (wy/dqkg fusion) — intersects with backward/intermediate propagation changes.

Suggested reviewers

  • yzhangcs
  • Nathancgy

"🐰
I hop through gates both safe and sound,
Sub-chunks gather where new tensors are found.
Recompute tiptoes or snores in the night,
Backward uncovers what forward held tight.
Thump-thump — chunked attention takes flight!"

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 11.54%, below the required threshold of 80.00%. | Run `@coderabbitai generate docstrings` to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The title clearly and concisely summarizes the main objective of the PR: introducing a lower-bound gate optimization to speed up `chunk_kda`, which aligns with the substantive changes across multiple kernel files and API signatures. |

📜 Recent review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8d1659f and 990535b.

📒 Files selected for processing (1)
  • fla/ops/common/chunk_delta_h.py


@gemini-code-assist
Contributor

Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the chunk_kda operation, a key component in the KDA (Kimi Delta Attention) architecture. The primary goal is to enhance performance and numerical stability through the introduction of a lower-bound gate mechanism and refined recomputation strategies. It also includes adaptations for different Triton compiler versions to ensure optimal performance across environments. These changes collectively aim to make the chunk_kda operation more efficient and robust.

Highlights

  • Performance Optimization: The chunk_kda operation has been significantly sped up by introducing a 'lowerbound gate' mechanism, which enhances numerical stability and efficiency.
  • Numerical Stability and Control: New safe_gate and lower_bound parameters have been added to control the range of gate values, improving the numerical stability of computations within the KDA framework.
  • Memory and Computation Management: A disable_recompute flag has been introduced, allowing the system to pass intermediate values instead of recomputing them during the backward pass, offering a trade-off between memory usage and computational cost.
  • Triton Version Adaptability: Conditional logic based on TRITON_ABOVE_3_4_0 has been implemented, along with a new backward kernel (chunk_kda_bwd_dqkwg), to optimize performance across different Triton compiler versions.
  • Code Refactoring and Modularity: Extensive refactoring has been performed in chunk_intra.py, including the introduction of new kernels like chunk_kda_fwd_kernel_intra_sub_chunk and conditional execution paths, leading to improved code modularity and potential performance gains.
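The memory/compute trade-off behind the `disable_recompute` flag follows a standard autograd pattern. The sketch below is an illustrative, self-contained example of that pattern only (a toy `tanh` stand-in, not the fla implementation): either save the expensive intermediate for backward, or drop it and recompute it there.

```python
import torch

class RecomputeToggle(torch.autograd.Function):
    """Minimal sketch of the memory/compute trade-off behind a
    `disable_recompute` flag; illustrative only, not the fla code."""

    @staticmethod
    def forward(ctx, x, disable_recompute=False):
        h = torch.tanh(x)  # stand-in for an expensive intermediate
        ctx.disable_recompute = disable_recompute
        if disable_recompute:
            ctx.save_for_backward(x, h)  # keep the intermediate: more memory
        else:
            ctx.save_for_backward(x)     # drop it: recompute in backward
        return 2.0 * h

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.disable_recompute:
            x, h = ctx.saved_tensors
        else:
            (x,) = ctx.saved_tensors
            h = torch.tanh(x)            # recompute the dropped intermediate
        # d/dx [2 * tanh(x)] = 2 * (1 - tanh(x)**2)
        return grad_out * 2.0 * (1.0 - h * h), None
```

Both paths produce identical gradients; the flag only chooses whether the backward pass pays in memory (saved tensors) or in FLOPs (recomputation).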


Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces optimizations to chunk_kda by adding a safe_gate option and a disable_recompute flag to trade memory for speed. The changes are extensive, touching the forward and backward passes, and introducing new Triton kernels and fallback paths for older Triton versions.

My review has identified a critical issue where a dimension assertion was removed while the underlying kernels still have the limitation, which could lead to runtime errors. I've also pointed out a minor style inconsistency and a potential precision degradation that comes with one of the optimizations.

Overall, the changes are substantial and aim to improve performance, but the critical issue should be addressed.

Comment thread: `fla/ops/kda/chunk.py` (outdated)

```diff
 if IS_TF32_SUPPORTED:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
+    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32')
```

Contributor


Severity: medium

Changing SOLVE_TRIL_DOT_PRECISION from 'tf32x3' to 'tf32' may improve performance but can reduce numerical precision. This seems to be reflected in the increased error tolerances in the associated tests. This precision-performance trade-off should be documented with a code comment if it is intentional.
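For intuition about the size of this trade-off: TF32 keeps only 10 mantissa bits versus float32's 23, while `tf32x3` recovers most of the lost precision by combining three TF32 products. The snippet below is an illustrative emulation of the mantissa loss only (real Tensor Cores round rather than truncate, so this is a worst-case sketch):

```python
import torch

def truncate_to_tf32(x: torch.Tensor) -> torch.Tensor:
    # TF32 keeps 10 mantissa bits vs float32's 23; masking the low 13 bits
    # of the float32 encoding emulates that loss (truncation, not rounding).
    bits = x.float().contiguous().view(torch.int32)
    return (bits & ~0x1FFF).view(torch.float32)
```

The relative error of this truncation is bounded by about 2⁻¹⁰ per operand, which is consistent with the loosened test tolerances mentioned above.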

zhiyuan1i force-pushed the msh/lowerbound-kda branch 2 times, most recently from 29027fc to c4a9f93, on December 30, 2025 at 01:33
Contributor

coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
tests/ops/test_kda.py (2)

311-313: Consider documenting the safe_gate restriction.

The assertion assert use_gate_in_kernel is False when safe_gate is True restricts safe_gate testing to non-kernel gate paths in varlen tests. If this is intentional (e.g., safe_gate with kernel gating not yet implemented for varlen), consider adding a comment or pytest.skip with a reason.

Looking at the test cases (lines 273-277), safe_gate=True only appears with use_gate_in_kernel=False, so this is consistent—but explicit documentation would help future maintainers.


355-356: Verify reference function selection for varlen tests with safe_gate.

The varlen test uses naive_kda_gate regardless of safe_gate value, while test_chunk (line 200) selects between naive_kda_gate and naive_kda_lowerbound_gate based on safe_gate.

Given that safe_gate test cases in varlen all have use_gate_in_kernel=False (lines 273-277), the gate function selection here may not matter since the kernel gate path isn't exercised. However, for consistency and future-proofing, consider using the same conditional selection pattern.

fla/ops/kda/chunk.py (1)

435-437: Improve docstring for cu_seqlens_cpu.

The docstring for cu_seqlens_cpu duplicates cu_seqlens's description. Consider clarifying its purpose—it appears to be a CPU-side copy used to avoid GPU synchronization when preparing chunk indices.

🔎 Suggested improvement

```diff
         cu_seqlens_cpu (torch.LongTensor):
-            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
-            consistent with the FlashAttention API.
+            CPU-resident copy of `cu_seqlens` of shape `[N+1]`.
+            When provided, avoids GPU-to-CPU synchronization during chunk index preparation.
```
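The role of `cu_seqlens` in chunk-index preparation can be shown with a naive CPU sketch. This is a hypothetical helper written for illustration (not the `prepare_chunk_indices` implementation): it expands the cumulative sequence lengths into one `(sequence, chunk)` pair per chunk, the kind of index table a varlen kernel launches over.

```python
import torch

def chunk_indices_from_cu_seqlens(cu_seqlens: torch.Tensor, chunk_size: int):
    # cu_seqlens: [N+1] cumulative token counts, e.g. [0, 5, 12] for two
    # sequences of lengths 5 and 7. Returns (seq_id, chunk_id) pairs.
    out = []
    for n in range(cu_seqlens.numel() - 1):
        seqlen = int(cu_seqlens[n + 1] - cu_seqlens[n])
        n_chunks = (seqlen + chunk_size - 1) // chunk_size  # ceil division
        out += [(n, c) for c in range(n_chunks)]
    return torch.tensor(out, dtype=torch.long)
```

Since this loop only reads scalar values, running it on a GPU-resident `cu_seqlens` forces device-to-host synchronization, which is exactly what passing a pre-copied `cu_seqlens_cpu` avoids.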
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 455360b and 87649fb.

📒 Files selected for processing (8)
  • fla/ops/kda/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_inter.py
  • fla/ops/kda/chunk_intra.py
  • fla/ops/kda/chunk_intra_token_parallel.py
  • fla/ops/kda/gate.py
  • fla/ops/kda/wy_fast.py
  • tests/ops/test_kda.py
🧰 Additional context used
🧬 Code graph analysis (3)
fla/ops/kda/chunk_inter.py (4)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (112-122)
fla/utils.py (1)
  • check_shared_mem (454-460)
fla/ops/kda/gate.py (2)
  • grid (215-216)
  • grid (439-439)
fla/ops/common/chunk_o.py (1)
  • grid (506-506)
fla/ops/kda/chunk.py (7)
fla/ops/common/chunk_delta_h.py (1)
  • chunk_gated_delta_rule_bwd_dhu (517-572)
fla/ops/kda/chunk_bwd.py (1)
  • chunk_kda_bwd_dAv (275-323)
fla/ops/kda/chunk_inter.py (1)
  • chunk_kda_bwd_dqkwg (142-191)
fla/ops/kda/chunk_intra.py (2)
  • chunk_kda_bwd_intra (836-900)
  • chunk_kda_fwd_intra (739-833)
fla/ops/kda/gate.py (1)
  • kda_gate_chunk_cumsum (416-456)
fla/ops/kda/wy_fast.py (1)
  • recompute_w_u_fwd (212-256)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (112-122)
tests/ops/test_kda.py (2)
fla/ops/kda/gate.py (3)
  • naive_kda_lowerbound_gate (50-62)
  • naive_kda_gate (19-47)
  • backward (301-310)
fla/ops/kda/chunk.py (1)
  • backward (307-379)
🔇 Additional comments (29)
fla/ops/kda/gate.py (1)

2-2: LGTM!

Minor header comment update with no functional impact.

fla/ops/kda/chunk_bwd.py (1)

220-220: LGTM!

Adding a comment to document the critical synchronization barrier is a good practice for maintainability.

fla/ops/kda/chunk_intra_token_parallel.py (2)

106-108: LGTM!

Separating the computation from masking improves code clarity while maintaining functional equivalence.


111-114: Verify the index calculation change.

The storage index changed from j % BC to j - i_ts. Given that j iterates from i_ts to at most i_ts + BC - 1, both expressions should yield values in [0, BC). The new form j - i_ts is more explicit about the offset from the sub-chunk start.

Please verify this change maintains correctness for edge cases where sequences don't align to chunk boundaries.
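The equivalence the comment asks about is easy to check in isolation. The values below are illustrative (`BC` standing in for the sub-chunk size and `i_ts` for the sub-chunk start, as in the kernel):

```python
BC = 16                            # sub-chunk size (illustrative value)
for i_ts in (0, 16, 32, 48):       # sub-chunk starts aligned to BC
    for j in range(i_ts, i_ts + BC):
        # j % BC == j - i_ts holds exactly when i_ts is a multiple of BC,
        # and both land in [0, BC); j - i_ts states the offset directly.
        assert j % BC == j - i_ts
```

Note the equivalence relies on `i_ts` being a multiple of `BC`; for a start offset that is not aligned, the two expressions diverge, which is why the more explicit `j - i_ts` form is preferable.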

fla/ops/kda/chunk_inter.py (3)

1-14: LGTM!

Standard copyright header, imports, and block size configuration consistent with other KDA modules.


31-139: Inter-chunk backward kernel implementation looks correct.

The kernel correctly:

  • Handles variable-length sequences via IS_VARLEN
  • Computes gradient contributions dq, dk, dw, dg using the chain rule
  • Applies gating via exp2 for numerical stability
  • Uses block pointers with boundary checks

Minor observation: m_t is computed at line 71 and redefined at line 124. While the second definition overwrites with identical logic, consider removing the redundant recomputation for clarity.


186-189: Bug: Incorrect parameter name in kernel call.

The kernel parameter is named K (line 49), but the call passes D=K. This will cause a parameter mismatch.

🔎 Proposed fix

```diff
         T=T,
         H=H,
-        D=K,
+        K=K,
         V=V,
         BT=BT,
```

Likely an incorrect or invalid review comment.

tests/ops/test_kda.py (2)

150-167: LGTM!

Test parameterization properly extended to cover new safe_gate and disable_recompute flags with good coverage of different combinations.


196-203: LGTM!

The safe_gate logic correctly:

  • Sets lower_bound when enabled
  • Clamps gate values for non-kernel gate path
  • Selects appropriate naive gate function for reference comparison
fla/ops/kda/wy_fast.py (4)

22-22: LGTM!

Adding tf32x3 precision option provides additional autotuning flexibility for TF32-capable hardware.


28-28: LGTM!

Kernel rename to recompute_w_u_fwd_kda_kernel improves namespacing and distinguishes KDA-specific kernels.


116-116: LGTM!

Kernel rename to prepare_wy_repr_bwd_kda_kernel maintains consistency with the forward kernel naming.


221-256: LGTM!

The expanded return type tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None] correctly reflects the addition of optional qg and kg outputs. The implementation properly handles the conditional allocation based on input parameters.

fla/ops/kda/chunk.py (7)

10-17: LGTM!

Import additions properly bring in the new modules needed for the extended KDA functionality.


32-47: LGTM!

Forward function signature and parameter propagation correctly extended for safe_gate and disable_recompute support.


73-77: LGTM!

Memory optimization pattern correctly implemented—intermediate tensors are set to None when disable_recompute=False to allow garbage collection, while preserving them when recomputation is disabled.


99-122: LGTM!

The conditional recomputation logic correctly handles both paths:

  • When disable_recompute=False: recomputes w, u, qg, kg, h, v_new
  • When disable_recompute=True: retrieves cached values from **kwargs

This provides flexibility for memory vs compute trade-offs.


151-197: LGTM!

The Triton version conditional properly gates between:

  • TRITON_ABOVE_3_4_0: Uses fused chunk_kda_bwd_wy_dqkg_fused kernel
  • Older versions: Falls back to separate chunk_kda_bwd_dqkwg + prepare_wy_repr_bwd

This ensures compatibility across Triton versions while enabling optimizations on newer versions.


288-293: LGTM!

Correctly nullifies g when it will be recomputed in backward (when disable_recompute=False and use_gate_in_kernel=True), saving memory while preserving g_org for recomputation.


370-377: Duplicated cumsum call in else branch.

The chunk_local_cumsum call in the else branch (lines 371-377) is identical to the call that occurs unconditionally before the if ctx.use_gate_in_kernel block (lines 353-359). This means when use_gate_in_kernel=False, the cumsum is applied twice.

Looking more carefully at lines 352-377: the chunk_local_cumsum at lines 353-359 is inside the if ctx.use_gate_in_kernel block. So the else branch is correct—it handles the non-kernel gate path. The indentation may be confusing but the logic is correct.

fla/ops/kda/chunk_intra.py (9)

9-16: LGTM!

Import updates and SOLVE_TRIL_DOT_PRECISION constant properly set up TF32-aware precision control for matrix operations.


54-54: LGTM!

Adding USE_SAFE_GATE parameter enables conditional kernel behavior for the safe-gate optimization path.


143-195: LGTM!

The nested conditionals for off-diagonal block computation are structurally sound. Each level (i_tc2 < T, i_tc3 < T) properly guards its computations, and the accumulated values (b_Aqk20, b_Akk20, etc.) are correctly computed using the gated exponential differences.


244-277: Conditional forward substitution for safe_gate path.

When USE_SAFE_GATE=True, the forward substitution is skipped here because the chunk_kda_fwd_kernel_intra_sub_chunk kernel (lines 636-736) handles it. This is a valid optimization that avoids redundant computation.
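As background, forward substitution for a unit-lower-triangular inverse, the operation both code paths compute, can be sketched in a few lines. This is an illustrative dense PyTorch version under the assumption that the matrix has unit diagonal, not the Triton kernel:

```python
import torch

def unit_tril_inverse(L: torch.Tensor) -> torch.Tensor:
    # Invert M = I + tril(L, -1) by forward substitution, row by row:
    # row i of the inverse is e_i minus the already-known rows weighted
    # by M's strictly-lower entries.
    T = L.shape[0]
    M = torch.eye(T, dtype=L.dtype) + torch.tril(L, diagonal=-1)
    A = torch.eye(T, dtype=L.dtype)
    for i in range(1, T):
        A[i] -= M[i, :i] @ A[:i]
    return A
```

Each row depends only on earlier rows, which is why the kernel can resolve a whole diagonal block with a short sequential loop (or, in the safe-gate path, inside the sub-chunk kernel itself).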


464-504: LGTM!

The SAFE_GATE path for diagonal block handling uses batched matrix operations instead of element-wise loops:

  • Loads full diagonal blocks of dAqk and dAkk
  • Applies proper masking (m_i_diag_qk, m_j_diag_qk) for numerical stability
  • Computes gradients via matrix operations

This approach improves GPU utilization compared to the scalar loop in the else branch.


560-609: LGTM!

The SAFE_GATE path for k-transpose computation mirrors the structure of the qk-diagonal path:

  • Uses gather or direct load based on USE_GATHER capability
  • Applies proper masking for boundary handling
  • Maintains numerical stability through the gating exponential differences

623-736: New sub-chunk kernel for safe_gate forward path.

The chunk_kda_fwd_kernel_intra_sub_chunk kernel implements the intra-chunk computation with improved numerical stability:

  • Uses a reference point b_gn at sub-chunk midpoint for stable exponential computation
  • Computes Aqk and Akk with proper masking
  • Includes forward substitution for matrix inverse

One observation: The forward substitution loop (lines 730-734) modifies b_Ai iteratively using scalar loads from Akk. Ensure the boundary condition min(BC, T - i_ti) correctly handles edge cases where the chunk extends beyond sequence length.
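The stability trick the comment describes, subtracting a reference gate value (`b_gn`) before exponentiation, is an exact rewrite of the decay matrix. A hypothetical dense sketch (illustrative, not the kernel):

```python
import torch

def stable_decay_matrix(g: torch.Tensor) -> torch.Tensor:
    """Intra-chunk decay matrix A[i, j] = 2**(G[i] - G[j]) for j <= i,
    with G the log-domain cumulative gate. Subtracting a mid-chunk
    reference leaves A unchanged while keeping every exp2 argument
    small; this mirrors the b_gn trick, not the kernel itself."""
    G = torch.cumsum(g, dim=-1)
    g_ref = G[G.numel() // 2]            # mid-chunk reference point
    pos = torch.exp2(G - g_ref)          # 2**(G[i] - ref)
    neg = torch.exp2(-(G - g_ref))       # 2**(ref - G[j])
    return torch.tril(pos[:, None] * neg[None, :])
```

Since `2**(G[i] - ref) * 2**(ref - G[j]) == 2**(G[i] - G[j])`, the reference cancels exactly; only the intermediate exponent magnitudes change.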


768-801: LGTM!

The conditional dispatch correctly routes to:

  • chunk_kda_fwd_kernel_intra_sub_chunk when safe_gate=True
  • chunk_kda_fwd_intra_token_parallel when safe_gate=False

This enables the safe-gate optimization path while maintaining backward compatibility.


823-833: LGTM!

Return signature correctly updated to (w, u, qg, kg, Aqk, Akk) to include the new qg output from recompute_w_u_fwd.

Comment thread fla/ops/kda/chunk.py
zhiyuan1i and others added 2 commits December 30, 2025 01:36
Co-authored-by: Wang Yucheng <wangyucheng@moonshot.cn>
Co-authored-by: yzhangcs <zhangyu02@moonshot.cn>
Co-authored-by: Chen Guangyu <chenguangyu@moonshot.cn>
Contributor

coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
fla/ops/kda/chunk.py (1)

217-218: Fix typo in assertion message (duplicate issue).

The assertion message says "sage_gate" but should say "safe_gate". This was flagged in a previous review.

🧹 Nitpick comments (2)
fla/ops/kda/chunk.py (2)

257-258: Use idiomatic boolean comparison.

Per PEP 8, prefer if not disable_recompute and use_gate_in_kernel: over if disable_recompute is False and use_gate_in_kernel:.

🔎 Proposed fix

```diff
-        if disable_recompute is False and use_gate_in_kernel:
+        if not disable_recompute and use_gate_in_kernel:
```

286-286: Use idiomatic boolean comparison.

Same as above - prefer if not ctx.disable_recompute and ctx.use_gate_in_kernel:.

🔎 Proposed fix

```diff
-        if ctx.disable_recompute is False and ctx.use_gate_in_kernel:
+        if not ctx.disable_recompute and ctx.use_gate_in_kernel:
```
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 87649fb and c4a9f93.

📒 Files selected for processing (7)
  • fla/ops/kda/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_intra.py
  • fla/ops/kda/chunk_intra_token_parallel.py
  • fla/ops/kda/gate.py
  • fla/ops/kda/wy_fast.py
  • tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • fla/ops/kda/gate.py
  • fla/ops/kda/chunk_bwd.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.

Applied to files:

  • fla/ops/kda/chunk_intra.py
🧬 Code graph analysis (3)
tests/ops/test_kda.py (2)
fla/ops/kda/gate.py (3)
  • naive_kda_lowerbound_gate (50-62)
  • naive_kda_gate (19-47)
  • backward (301-310)
fla/ops/kda/chunk.py (1)
  • backward (276-348)
fla/ops/kda/chunk_intra.py (4)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (112-122)
fla/ops/utils/op.py (1)
  • gather (25-31)
fla/ops/kda/chunk_intra_token_parallel.py (2)
  • grid (152-152)
  • chunk_kda_fwd_intra_token_parallel (117-169)
fla/ops/kda/wy_fast.py (1)
  • recompute_w_u_fwd (212-256)
fla/ops/kda/chunk.py (5)
fla/ops/kda/gate.py (2)
  • kda_gate_bwd (234-272)
  • kda_gate_chunk_cumsum (416-456)
fla/ops/kda/wy_fast.py (1)
  • recompute_w_u_fwd (212-256)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (112-122)
fla/ops/kda/chunk_intra.py (1)
  • chunk_kda_fwd_intra (739-833)
fla/ops/common/chunk_delta_h.py (1)
  • chunk_gated_delta_rule_fwd_h (464-514)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (21)
fla/ops/kda/chunk_intra_token_parallel.py (1)

106-114: LGTM - Refactored masking and indexing logic.

The changes correctly:

  1. Apply the mask post-multiply via tl.where(m_k[None, :], b_kgj, 0.0) instead of inline
  2. Use multiplicative mask * tl.where(j < i_t, 1.0, 0.0) for b_Akk computation
  3. Change index from j % BC to j - i_ts which is equivalent since j ranges from i_ts to i_ts + BC - 1

These are functionally equivalent transformations that align with the new gating/disable-recompute flow.

fla/ops/kda/wy_fast.py (3)

22-22: LGTM - Extended DOT_PRECISION autotune choices.

Adding "tf32" provides a useful middle-ground between "tf32x3" (higher precision, slower) and "ieee" (full precision, slowest), allowing the autotuner to find better performance/precision tradeoffs.


28-28: LGTM - Kernel renames for KDA-specific variants.

Renaming to recompute_w_u_fwd_kda_kernel and prepare_wy_repr_bwd_kda_kernel improves clarity by distinguishing these KDA-specific kernels from potential generic variants.

Also applies to: 116-116


221-256: LGTM - Extended return values for disable_recompute support.

The function now returns (w, u, qg, kg) where qg and kg are optionally computed based on whether q and gk are provided. This aligns with the disable_recompute flow where these intermediates need to be saved during forward pass to avoid recomputation in backward.

tests/ops/test_kda.py (4)

150-168: LGTM - Extended test parameterization for gating controls.

The test matrix now covers combinations of safe_gate and disable_recompute flags, providing good coverage of the new code paths.


196-203: LGTM - Conditional gating logic for safe_gate.

When safe_gate=True, the test correctly:

  1. Sets lower_bound = -5.0
  2. Clamps g to [-5, 0] when not using gate in kernel
  3. Selects naive_kda_lowerbound_gate as the reference function
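Based on that test description, the reference behaviour amounts to clamping the log-domain gate. A minimal sketch (hypothetical name; the clamp semantics are assumed from the clamp-to-`[-5, 0]` step stated above):

```python
import torch

def lowerbound_gate_ref(g_log: torch.Tensor, lower_bound: float = -5.0) -> torch.Tensor:
    # Clamp log-domain gates into [lower_bound, 0], as the non-kernel test
    # path does; bounding the gate from below is what makes the pure
    # Tensor-Core diagonal path described in the PR numerically safe.
    return g_log.clamp(min=lower_bound, max=0.0)
```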

263-264: Verify tolerance changes for gradient assertions.

The tolerances for dA (0.003) and dbias (0.008) appear tighter than before. Ensure these are appropriate for the precision characteristics of the new gating paths.


355-356: Reference uses naive_kda_gate but tests with use_gate_in_kernel.

When use_gate_in_kernel=True in varlen tests, the reference path uses naive_kda_gate but the test doesn't account for safe_gate in the reference computation. Since varlen tests with safe_gate=True require use_gate_in_kernel=False, this is currently safe, but the logic could be clearer.

fla/ops/kda/chunk.py (4)

31-46: LGTM - Extended forward signature and data flow.

The forward function correctly:

  1. Accepts new safe_gate and disable_recompute parameters
  2. Passes them through to chunk_kda_fwd_intra
  3. Receives extended return values (w, u, qg, kg, Aqk, Akk)

72-76: LGTM - Conditional memory cleanup.

When disable_recompute=False, intermediates are set to None to free memory since they'll be recomputed in backward. When disable_recompute=True, they're preserved for use in backward.


94-121: LGTM - Backward supports both recompute and cached paths.

The backward function now:

  1. Recomputes w, u, qg, kg, h, v_new when disable_recompute=False
  2. Uses cached values from kwargs when disable_recompute=True

This provides flexibility between memory usage (recompute) and speed (cached).


347-348: Verify return tuple length matches forward parameters.

The backward returns 17 values (including Nones). Verify this matches the forward's positional arguments count: q, k, v, g, beta, A_log, dt_bias, scale, initial_state, output_final_state, use_qk_l2norm_in_kernel, use_gate_in_kernel, cu_seqlens, cu_seqlens_cpu, safe_gate, lower_bound, disable_recompute = 17 parameters. This looks correct.

fla/ops/kda/chunk_intra.py (9)

9-16: LGTM - Updated imports and precision configuration.

The imports are correctly updated to include gather, IS_GATHER_SUPPORTED, and prepare_chunk_indices. The SOLVE_TRIL_DOT_PRECISION change from 'tf32x3' to 'tf32' was noted in a previous review as an intentional precision-performance tradeoff.


54-54: LGTM - Added USE_SAFE_GATE parameter.

The USE_SAFE_GATE constexpr controls whether to skip forward substitution when the sub-chunk kernel already handles the diagonal block inversion.


143-195: LGTM - Restructured off-diagonal block computation.

The nested if-blocks for i_tc2 and i_tc3 correctly compute off-diagonal Aqk and Akk blocks. The structure change improves code organization by grouping related computations together.


244-277: LGTM - Conditional forward substitution.

When USE_SAFE_GATE=False, the standard forward substitution is performed to compute the inverse. When USE_SAFE_GATE=True, this is skipped because the chunk_kda_fwd_kernel_intra_sub_chunk kernel already computes the inverse directly.
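For reference, the forward substitution being skipped is the standard row-by-row inversion of a unit lower-triangular system; a NumPy toy of the idea (written as inverting I + A with A strictly lower triangular; the Triton kernel operates on blocked tiles rather than single rows):

```python
import numpy as np

def invert_unit_lower(A):
    """Invert (I + A), A strictly lower triangular, via forward
    substitution: row i of the inverse depends only on rows < i."""
    n = A.shape[0]
    T = np.eye(n)
    for i in range(1, n):
        # From (I + A) @ T = I: T[i, :i] = -A[i, :i] @ T[:i, :i]
        T[i, :i] = -A[i, :i] @ T[:i, :i]
    return T

rng = np.random.default_rng(0)
A = np.tril(rng.normal(size=(8, 8)), k=-1)
T = invert_unit_lower(A)
```

The safe-gate path can skip this loop precisely because the sub-chunk kernel stores blocks that are already inverted.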


464-504: Verify SAFE_GATE gradient computation correctness.

The SAFE_GATE branch uses a block-based approach with masks m_i_diag_qk and m_j_diag_qk instead of the element-by-element loop. The computation:

  1. Loads diagonal dAqk/dAkk blocks
  2. Applies masks for valid positions
  3. Computes exp2(b_g - b_gn) and exp2(-(b_g - b_gn)) for numerical stability
  4. Accumulates gradients via tl.dot

This should be mathematically equivalent to the loop-based approach but more efficient.
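The claimed equivalence is easy to check on a toy decay matrix: the causally masked outer product of exp2(+(g - g_mid)) and exp2(-(g - g_mid)) reproduces the element-wise exp2(g_i - g_j) entries exactly, since the shift cancels (illustrative NumPy, with g_mid standing in for b_gn):

```python
import numpy as np

# g: log2-domain gate cumsum for one sub-chunk, one feature (toy values)
rng = np.random.default_rng(1)
g = np.cumsum(rng.uniform(-0.5, 0.0, size=16))
n = len(g)

# Loop-based: decay from position j to i (j <= i), one element at a time
A_loop = np.zeros((n, n))
for i in range(n):
    for j in range(i + 1):
        A_loop[i, j] = 2.0 ** (g[i] - g[j])

# Block-based: shift by a mid-block reference, form the outer product of
# exp2(+shifted) and exp2(-shifted), then apply the lower-triangular mask
g_mid = g[n // 2]
A_block = np.tril(np.exp2(g - g_mid)[:, None] * np.exp2(-(g - g_mid))[None, :])
```

Mathematically exp2(g_i - g_mid) * exp2(-(g_j - g_mid)) = exp2(g_i - g_j), so the two constructions agree; the block form maps onto tl.dot and Tensor Cores.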


623-737: LGTM - New intra sub-chunk kernel for safe gating.

The chunk_kda_fwd_kernel_intra_sub_chunk kernel:

  1. Computes Aqk and Akk for diagonal blocks using gating with numerical stability (subtracting b_gn, taken from the middle of the block)
  2. Performs forward substitution inline to compute the inverse
  3. Uses gather when supported, otherwise falls back to offset-based loading

The kernel correctly handles:

  • Boundary checks via m_c mask
  • Lower triangular masking via m_Aqk and m_Akk
  • Identity addition for the inverse (b_Ai += m_I)

768-801: LGTM - Conditional kernel dispatch based on safe_gate.

When safe_gate=True, uses the new chunk_kda_fwd_kernel_intra_sub_chunk kernel. Otherwise, uses the existing chunk_kda_fwd_intra_token_parallel path. This provides flexibility without breaking existing functionality.


823-833: LGTM - Updated return signature.

The function now returns (w, u, qg, kg, Aqk, Akk) matching the updated caller expectations. The q parameter to recompute_w_u_fwd is conditionally passed based on disable_recompute to compute qg only when needed.


690-696: No action needed. The code correctly branches on USE_GATHER before calling gather(). When IS_GATHER_SUPPORTED=False, USE_GATHER is False, so the else branch executes with tl.load() and gather() is never called. The fallback gather() returning None is not executed in practice—it exists only to satisfy the Triton compiler. The logic is correct.

Likely an incorrect or invalid review comment.

Comment thread tests/ops/test_kda.py
Comment on lines +311 to +313
if safe_gate:
assert use_gate_in_kernel is False
g = g.clamp(-5, 0)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's examine the test file around the mentioned lines
echo "=== Examining lines 273-313 in test_kda.py ==="
sed -n '273,313p' tests/ops/test_kda.py

echo -e "\n=== Searching for safe_gate=True in test_kda.py ==="
rg -n "safe_gate.*True" tests/ops/test_kda.py -A3 -B3

Repository: fla-org/flash-linear-attention

Length of output: 1814


🏁 Script executed:

#!/bin/bash
# Search for documentation or comments about safe_gate constraint
echo "=== Searching for comments about safe_gate or use_gate_in_kernel constraint ==="
rg -n "safe_gate|use_gate_in_kernel" tests/ops/test_kda.py -B2 | head -50

echo -e "\n=== Checking for docstring in test_chunk_varlen function ==="
sed -n '281,295p' tests/ops/test_kda.py

echo -e "\n=== Searching for safe_gate in other files ==="
rg -l "safe_gate" --type py | head -10

Repository: fla-org/flash-linear-attention

Length of output: 2478


🏁 Script executed:

#!/bin/bash
# Look at the constraint in the first test function (test_chunk) around line 196-198
echo "=== Context in test_chunk function around safe_gate constraint ==="
sed -n '190,205p' tests/ops/test_kda.py

echo -e "\n=== Checking implementation files for safe_gate constraint documentation ==="
rg -n "safe_gate" fla/ops/kda/chunk.py -B3 -A3 | head -50

echo -e "\n=== Checking chunk_intra.py ==="
rg -n "safe_gate" fla/ops/kda/chunk_intra.py -B2 -A2 | head -30

Repository: fla-org/flash-linear-attention

Length of output: 3349


🏁 Script executed:

#!/bin/bash
# Find test_chunk parameters and check for safe_gate=True cases
echo "=== Extracting test_chunk parameters (lines 140-170) ==="
sed -n '140,170p' tests/ops/test_kda.py

echo -e "\n=== Count test cases with safe_gate=True across both test functions ==="
echo "test_chunk with safe_gate=True:"
sed -n '140,170p' tests/ops/test_kda.py | grep -c "True.*True" || echo "Checking..."
sed -n '140,170p' tests/ops/test_kda.py | tail -20

echo -e "\ntest_chunk_varlen with safe_gate=True:"
sed -n '273,278p' tests/ops/test_kda.py

Repository: fla-org/flash-linear-attention

Length of output: 2845


Add a descriptive message to the safe_gate assertion in the varlen test.

The assertion assert use_gate_in_kernel is False correctly enforces a varlen-specific constraint (unlike the non-varlen test_chunk which permits this combination), but it lacks an error message explaining why in-kernel gate computation is unsupported with safe gating for variable-length sequences. Update to: assert use_gate_in_kernel is False, "safe_gate is not supported with in-kernel gate computation for varlen".

🤖 Prompt for AI Agents
In tests/ops/test_kda.py around lines 311 to 313, the assertion for safe_gate
lacks a descriptive error message; update the assertion to include a clear
message such as "safe_gate is not supported with in-kernel gate computation for
varlen" so that if the test fails it explains that in-kernel gate computation is
unsupported with safe gating for variable-length sequences.

Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
tests/ops/test_kda.py (1)

311-313: Add a descriptive message to the safe_gate assertion.

The assertion assert use_gate_in_kernel is False correctly enforces a constraint, but it lacks an error message explaining why in-kernel gate computation is unsupported with safe gating for variable-length sequences.

🔎 Proposed fix
         if safe_gate:
-            assert use_gate_in_kernel is False
+            assert use_gate_in_kernel is False, "safe_gate is not supported with in-kernel gate computation for varlen"
             g = g.clamp(-5, 0)
fla/ops/kda/chunk.py (1)

219-221: Fix typo in assertion message.

"sage_gate" should be "safe_gate".

🔎 Proposed fix
         if safe_gate:
-            assert lower_bound is not None, "lower_bound must be set when use sage_gate"
+            assert lower_bound is not None, "lower_bound must be set when using safe_gate"
🧹 Nitpick comments (1)
fla/ops/kda/chunk.py (1)

259-260: Use idiomatic boolean comparison.

Per Python style guidelines (PEP 8), prefer not disable_recompute over disable_recompute is False for boolean checks.

🔎 Proposed fix
-        if disable_recompute is False and use_gate_in_kernel:
+        if not disable_recompute and use_gate_in_kernel:
             g = None  # type: ignore
-        if ctx.disable_recompute is False and ctx.use_gate_in_kernel:
+        if not ctx.disable_recompute and ctx.use_gate_in_kernel:
             g = kda_gate_chunk_cumsum(

Also applies to: 288-289

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c4a9f93 and e9a0605.

📒 Files selected for processing (7)
  • fla/ops/kda/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_intra.py
  • fla/ops/kda/chunk_intra_token_parallel.py
  • fla/ops/kda/gate.py
  • fla/ops/kda/wy_fast.py
  • tests/ops/test_kda.py
✅ Files skipped from review due to trivial changes (1)
  • fla/ops/kda/gate.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.

Applied to files:

  • fla/ops/kda/chunk_intra.py
🧬 Code graph analysis (3)
tests/ops/test_kda.py (2)
fla/ops/kda/gate.py (3)
  • naive_kda_lowerbound_gate (50-62)
  • naive_kda_gate (19-47)
  • backward (301-310)
fla/ops/kda/chunk.py (1)
  • backward (278-350)
fla/ops/kda/chunk.py (5)
fla/ops/kda/gate.py (1)
  • kda_gate_bwd (234-272)
fla/ops/kda/wy_fast.py (1)
  • recompute_w_u_fwd (212-256)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (429-469)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (112-122)
fla/ops/kda/chunk_intra.py (1)
  • chunk_kda_fwd_intra (739-833)
fla/ops/kda/chunk_intra.py (3)
fla/ops/utils/op.py (1)
  • gather (25-31)
fla/ops/kda/chunk_intra_token_parallel.py (2)
  • grid (152-152)
  • chunk_kda_fwd_intra_token_parallel (117-169)
fla/ops/kda/wy_fast.py (1)
  • recompute_w_u_fwd (212-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (20)
fla/ops/kda/chunk_bwd.py (1)

220-220: LGTM!

Adding this cautionary comment to the tl.debug_barrier() is a good practice. Synchronization barriers in Triton kernels are critical for correctness, and this annotation helps prevent accidental removal during future refactoring.

fla/ops/kda/wy_fast.py (3)

22-22: Adding tf32 precision option.

Adding "tf32" to the DOT_PRECISION choices alongside "tf32x3" and "ieee" is a reasonable extension. This provides an additional precision-performance tradeoff option. Based on a past review comment, this may improve performance but could reduce numerical precision compared to tf32x3.
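To get a feel for the gap between the two modes, tf32's 10-bit mantissa can be roughly emulated on CPU by zeroing the low 13 mantissa bits of float32 (a sketch only; real Tensor Cores round to nearest rather than truncate, and tf32x3 recovers near-fp32 accuracy by combining three tf32 products):

```python
import numpy as np

def tf32_trunc(x):
    """Emulate tf32's 10-bit mantissa by zeroing the low 13 of the 23
    float32 mantissa bits (truncation; hardware rounds to nearest)."""
    u = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (u & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.normal(size=256).astype(np.float32)
b = rng.normal(size=256).astype(np.float32)

# Compare dot-product error against a float64 reference
exact = np.dot(a.astype(np.float64), b.astype(np.float64))
fp32_err = abs(float(np.dot(a, b)) - exact)
tf32_err = abs(float(np.dot(tf32_trunc(a), tf32_trunc(b))) - exact)
```

The truncated-input error dominates the fp32 accumulation error, which is the precision being traded for throughput when the autotuner picks "tf32" over "tf32x3".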


28-28: LGTM - Kernel renaming to KDA-specific variants.

Renaming recompute_w_u_fwd_kernel → recompute_w_u_fwd_kda_kernel and prepare_wy_repr_bwd_kernel → prepare_wy_repr_bwd_kda_kernel clearly indicates these are KDA-specific kernel implementations. This improves code organization and prevents confusion with potential non-KDA variants.

Also applies to: 116-116


221-256: LGTM - Extended return type for recompute_w_u_fwd.

The function now returns (w, u, qg, kg) instead of just (w, u), properly propagating the gated query and key tensors needed when disable_recompute=True. The conditional allocation of qg and kg based on input presence (lines 233-234) is memory-efficient.

tests/ops/test_kda.py (4)

150-167: LGTM - Test parameterization extended for new features.

Good coverage of safe_gate and disable_recompute combinations across different test configurations. The test IDs properly include the new parameters for clear identification.


196-203: LGTM - Safe gate handling with lower_bound and naive gate function selection.

Correctly sets lower_bound = -5.0 when safe_gate=True and selects the appropriate naive gate function (naive_kda_lowerbound_gate vs naive_kda_gate) for reference comparison. The clamping to [-5, 0] when not using in-kernel gate ensures consistency with the safe gate assumptions.


263-264: Note the tolerance adjustments for gradient checks.

The tolerances for dA (0.003) and dbias (0.008) are relatively tight. Consider whether these might need adjustment if tests become flaky on different hardware configurations.


355-356: Verify reference implementation uses correct gate function for varlen tests.

When use_gate_in_kernel=True in varlen tests, the reference uses naive_kda_gate but not naive_kda_lowerbound_gate even when safe_gate could be True. However, the current parameterization at lines 273-277 shows safe_gate=False when use_gate_in_kernel=True for varlen, so this is currently safe. If future tests combine safe_gate=True with use_gate_in_kernel=True in varlen, this would need updating.

fla/ops/kda/chunk_intra.py (6)

9-11: LGTM - Import updates for new functionality.

The imports are correctly updated to include prepare_chunk_indices from utils, and gather with IS_GATHER_SUPPORTED for the conditional gather-based data access paths.


14-16: Precision change from tf32x3 to tf32.

This change from 'tf32x3' to 'tf32' may improve performance but could reduce numerical precision. This aligns with the test tolerance adjustments observed in test_kda.py.


244-277: Safe gate path skips forward substitution on diagonals.

When USE_SAFE_GATE=True, the forward substitution loop (lines 253-277) is skipped. This relies on the chunk_kda_fwd_kernel_intra_sub_chunk kernel having already computed the appropriate diagonal inverse. Ensure this assumption is validated by the test coverage.


464-504: Safe gate backward path uses matrix operations instead of element-wise loop.

The SAFE_GATE branch replaces the element-wise loop with vectorized matrix operations using exp2(b_g - b_gn). This should be more efficient on tensor cores. The masking logic (m_i_diag_qk, m_j_diag_qk) correctly handles boundary conditions.


623-737: New intra-sub-chunk kernel for safe gate path.

This new kernel chunk_kda_fwd_kernel_intra_sub_chunk handles the safe gate computation at sub-chunk granularity. Key observations:

  1. Lines 698-703: Uses a pivot point at BC//2 for numerical stability, keeping gate differences small
  2. Lines 729-736: Performs forward substitution in-place after storing initial values
  3. The kernel correctly handles boundary conditions with m_c masking

The tl.debug_barrier() at line 723 ensures stores complete before the forward substitution reads from global memory.
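The benefit of the BC//2 pivot is easy to demonstrate: referencing the log-domain cumsum to the sub-chunk midpoint halves the worst-case exponent fed to exp2, so both exp2(+d) and exp2(-d) stay finite where a start-of-chunk reference would overflow float32 (toy numbers, not the kernel's actual gates):

```python
import numpy as np

# Log2-domain cumulative gate over one 64-token chunk; strongly
# negative gates make the spread across the chunk large.
g = np.cumsum(np.full(64, -4.0))   # g runs from -4 down to -256

with np.errstate(over='ignore'):
    # Naive: reference the chunk start; the largest exponent is 252,
    # far beyond float32's ~2**128 limit, so the tail overflows to inf.
    naive = np.exp2(-(g - g[0]).astype(np.float32))

# Median shift: reference the midpoint; the worst exponent drops to
# 124 on one side and -128 on the other, both representable.
shifted = np.exp2(-(g - g[len(g) // 2]).astype(np.float32))
```

This is the same stabilization the sub-chunk kernel applies per feature before forming the exp2 outer products on Tensor Cores.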


767-801: Safe gate path selection in forward wrapper.

Good conditional dispatch between the new chunk_kda_fwd_kernel_intra_sub_chunk for safe_gate=True and the existing chunk_kda_fwd_intra_token_parallel for the standard path. The grid dimensions are correctly adjusted for the sub-chunk kernel.

fla/ops/kda/chunk.py (5)

2-2: LGTM - Attribution added.

Good practice to acknowledge contributing teams.


72-76: Memory cleanup for non-recompute path.

Setting intermediate tensors to None when disable_recompute=False is good for memory efficiency. The extended return tuple properly propagates all needed values for the backward pass when disable_recompute=True.


94-123: Conditional recomputation in backward pass.

The backward path correctly handles both cases:

  • When disable_recompute=False: recomputes w, u, qg, kg, h, v_new
  • When disable_recompute=True: retrieves saved values from kwargs

This provides a memory-compute tradeoff controlled by the user.


349-350: Verify backward return count matches forward input count.

The backward returns 17 values. Verify this matches the forward's positional arguments:

  1. dq, 2. dk, 3. dv, 4. dg, 5. db, 6. dA, 7. dbias, 8. None (scale), 9. dh0, 10-17. None values for remaining args.

The forward has: q, k, v, g, beta, A_log, dt_bias, scale, initial_state, output_final_state, use_qk_l2norm_in_kernel, use_gate_in_kernel, cu_seqlens, cu_seqlens_cpu, safe_gate, lower_bound, disable_recompute (17 args).

The count matches.


406-420: LGTM - Well-documented new parameters.

The docstring properly documents the new cu_seqlens_cpu, safe_gate, lower_bound, and disable_recompute parameters with clear descriptions of their purpose and recommended usage patterns.

fla/ops/kda/chunk_intra_token_parallel.py (1)

106-114: No action required—the index changes are correct.

The review comment incorrectly identifies the index change. The code actually uses:

  • Line 113: j % BT for the Aqk buffer (token-level, size BT)
  • Line 114: j - i_ts for the Akk buffer (sub-chunk-level, size BC)

These are not interchangeable indices—they serve different purposes for buffers with different logical sizes. Both formulas are mathematically correct: j % BT maps the loop variable j to its position within a token, while j - i_ts gives the offset within the sub-chunk (where j ranges from i_ts to min(i_t + 1, min(T, i_ts + BC))). The computational changes (masking approach and multiplicative conditional) are also correct.

Comment on lines +690 to +696
if USE_GATHER:
b_gn = gather(b_g, tl.full([1, BK], min(BC//2, T - i_ti - 1), dtype=tl.int16), axis=0)
else:
# calculate offset
p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H*K + tl.arange(0, BK)
b_gn = tl.load(p_gn, mask=tl.arange(0, BK) < K, other=0.0)
b_gn = b_gn[None, :]

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check how IS_GATHER_SUPPORTED is defined
rg -n "IS_GATHER_SUPPORTED" --type py -B3 -A3 | head -60

Repository: fla-org/flash-linear-attention

Length of output: 3259


🏁 Script executed:

#!/bin/bash
# Find the gather function in utils/op.py
rg -n "def gather" fla/ops/utils/ --type py -B2 -A15

Repository: fla-org/flash-linear-attention

Length of output: 1093


🏁 Script executed:

#!/bin/bash
# Check the full context around lines 690-696 in chunk_intra.py
sed -n '680,705p' fla/ops/kda/chunk_intra.py | cat -n

Repository: fla-org/flash-linear-attention

Length of output: 1282


🏁 Script executed:

#!/bin/bash
# Check for USE_GATHER definition and how it's set
rg -n "USE_GATHER\s*=" fla/ops/kda/chunk_intra.py --type py -B2 -A2

Repository: fla-org/flash-linear-attention

Length of output: 321


🏁 Script executed:

#!/bin/bash
# Search for tests related to chunk_intra
find . -name "*test*" -type f -name "*.py" | xargs grep -l "chunk_intra" 2>/dev/null | head -10

Repository: fla-org/flash-linear-attention

Length of output: 56


🏁 Script executed:

#!/bin/bash
# Check if there are tests that explicitly test USE_GATHER=False or IS_GATHER_SUPPORTED
rg -l "USE_GATHER|IS_GATHER_SUPPORTED" --type py tests/ 2>/dev/null | head -20

Repository: fla-org/flash-linear-attention

Length of output: 56


🏁 Script executed:

#!/bin/bash
# Find all test files in the repository
find . -type f -name "test_*.py" -o -name "*_test.py" | head -20

Repository: fla-org/flash-linear-attention

Length of output: 896


🏁 Script executed:

#!/bin/bash
# Look for tests directory
find . -type d -name "tests" -o -type d -name "test" | head -10

Repository: fla-org/flash-linear-attention

Length of output: 81


🏁 Script executed:

#!/bin/bash
# Check if tests directory exists and what's in it
if [ -d "tests" ]; then ls -la tests/ 2>/dev/null | head -20; else echo "No tests directory found"; fi

Repository: fla-org/flash-linear-attention

Length of output: 350


🏁 Script executed:

#!/bin/bash
# Check the KDA model test file
cat -n tests/models/test_modeling_kda.py | head -100

Repository: fla-org/flash-linear-attention

Length of output: 1965


🏁 Script executed:

#!/bin/bash
# Look for ops tests related to chunk_intra or gather
ls -la tests/ops/ | grep -i kda

Repository: fla-org/flash-linear-attention

Length of output: 135


🏁 Script executed:

#!/bin/bash
# Check what's in ops tests directory
ls -la tests/ops/ | head -30

Repository: fla-org/flash-linear-attention

Length of output: 1991


🏁 Script executed:

#!/bin/bash
# Check test_kda.py for any tests related to gather or USE_GATHER
rg -n "gather|USE_GATHER|IS_GATHER" tests/ops/test_kda.py --type py

Repository: fla-org/flash-linear-attention

Length of output: 56


🏁 Script executed:

#!/bin/bash
# Check the full test_kda.py file to understand test coverage
cat -n tests/ops/test_kda.py | head -150

Repository: fla-org/flash-linear-attention

Length of output: 5815


🏁 Script executed:

#!/bin/bash
# Continue viewing test_kda.py to see all tests
wc -l tests/ops/test_kda.py

Repository: fla-org/flash-linear-attention

Length of output: 99


🏁 Script executed:

#!/bin/bash
# Check if there are any environment variable overrides or mocking for IS_GATHER_SUPPORTED in tests
rg -n "IS_GATHER_SUPPORTED|gather|USE_GATHER" tests/ -A3 -B3 --type py

Repository: fla-org/flash-linear-attention

Length of output: 1560


🏁 Script executed:

#!/bin/bash
# Check the conftest.py file for any fixtures related to gather
cat -n tests/ops/conftest.py

Repository: fla-org/flash-linear-attention

Length of output: 7928


Ensure the USE_GATHER=False fallback path is tested.

The scalar load fallback at lines 694-696 is correctly implemented and activates when IS_GATHER_SUPPORTED=False. However, the test suite (test_kda.py) does not explicitly exercise this fallback path—tests always use the IS_GATHER_SUPPORTED value at runtime. On hardware with gather support, this fallback branch never executes. Consider adding a test that mocks IS_GATHER_SUPPORTED=False or parameterizes tests to cover both paths.

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_intra.py around lines 690 to 696, the fallback branch used
when USE_GATHER/IS_GATHER_SUPPORTED is False (the scalar tl.load path) is not
covered by tests; update tests to explicitly exercise this branch by adding or
modifying test_kda.py to run the kernel with IS_GATHER_SUPPORTED
mocked/overridden to False (or parameterize the test to run with both True and
False), verify outputs/behavior match the gather-enabled path, and ensure any
necessary seeding or environment setup is applied so the scalar load path
executes deterministically during the test.
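One lightweight way to exercise that branch without gather-capable hardware is to patch the flag on the module that reads it. The sketch below uses a stand-in module and stdlib unittest.mock (hypothetical: the real patch target would be fla.ops.kda.chunk_intra.IS_GATHER_SUPPORTED, and because USE_GATHER is a Triton constexpr, each flag value compiles and caches a separate kernel variant):

```python
import types
from unittest import mock

# Stand-in for a kernel module that branches on a support flag at
# launch time (mirrors how the wrapper reads IS_GATHER_SUPPORTED).
kern = types.ModuleType("toy_kernel")
kern.IS_GATHER_SUPPORTED = True

def launch():
    # The branch is decided when the wrapper runs, so patching the
    # module attribute before the call flips the code path taken.
    return "gather" if kern.IS_GATHER_SUPPORTED else "scalar_load"

kern.launch = launch

assert kern.launch() == "gather"
with mock.patch.object(kern, "IS_GATHER_SUPPORTED", False):
    assert kern.launch() == "scalar_load"
```

In a pytest suite the same effect is available via the monkeypatch fixture; either way the patch must be active before the kernel wrapper is called so the fallback tl.load path is what gets compiled and run.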


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
fla/ops/kda/chunk.py (1)

483-487: Comprehensive validation for safe_gate mode.

The validation correctly ensures that:

  • lower_bound is specified when safe_gate=True
  • lower_bound is within the safe range [-5, 0)

This prevents runtime errors from invalid configurations.

The static analysis tool suggests extracting the long exception messages into constants or custom exception classes to comply with TRY003. However, this is a minor style preference, and the current approach provides clear, contextualized error messages.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9a0605 and 2e003e3.

📒 Files selected for processing (3)
  • benchmarks/ops/benchmark_kda.py
  • fla/ops/kda/chunk.py
  • tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/ops/test_kda.py
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/ops/benchmark_kda.py (2)
fla/ops/kda/chunk.py (1)
  • backward (278-347)
fla/ops/kda/gate.py (1)
  • backward (301-310)
🪛 Ruff (0.14.10)
fla/ops/kda/chunk.py

485-485: Avoid specifying long messages outside the exception class

(TRY003)


487-487: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (9)
benchmarks/ops/benchmark_kda.py (1)

44-44: LGTM: Direct provider comparison is clearer.

The switch from provider_base == 'xxx' to provider == 'xxx' simplifies the logic and improves readability.

Also applies to: 61-61, 73-73, 92-92, 110-110

fla/ops/kda/chunk.py (8)

2-2: LGTM: Import updates align with new gating logic.

The copyright update and import changes (switching to kda_gate_chunk_cumsum and adding prepare_chunk_indices) correctly support the new lower-bound-aware chunked gating functionality.

Also applies to: 11-11, 15-15


31-32: LGTM: Memory-efficient recomputation strategy.

The conditional memory cleanup at lines 72-76 provides a good trade-off: by default (disable_recompute=False), intermediates are freed to save memory and recomputed during backward pass. When disable_recompute=True, intermediates are preserved for faster backward computation at the cost of memory.

Also applies to: 34-47, 72-76


94-96: LGTM: Backward pass correctly handles recomputation.

The conditional recomputation logic at lines 98-123 properly mirrors the forward pass strategy. When disable_recompute=False, intermediates are recomputed; otherwise, they're retrieved from saved kwargs. The use of **kwargs at line 96 provides good backward compatibility.

Also applies to: 98-123, 184-184


208-211: LGTM: Forward pass properly implements safe gating with lower bound.

The forward pass correctly:

  • Computes chunk indices from cu_seqlens when provided
  • Validates safe_gate requirements (assertion at line 220, already addressed per past comments)
  • Applies kda_gate_chunk_cumsum with lower_bound when safe gating is enabled
  • Conditionally cleans up intermediate g to save memory
  • Extends saved tensors to include all intermediates needed for the backward pass

Also applies to: 213-238, 244-257, 259-265, 267-267, 269-269, 272-272


283-287: LGTM: Backward pass correctly propagates new state and flags.

The backward pass properly:

  • Unpacks the extended saved tensors including intermediates
  • Conditionally recomputes the gate when needed (lines 288-298)
  • Threads safe_gate, disable_recompute, and all intermediates through to chunk_kda_bwd
  • Correctly applies reverse cumsum in both gating branches

Also applies to: 288-298, 314-316, 324-345


357-366: LGTM: Well-documented API extensions.

The public API signature is properly extended with type hints, and the docstring comprehensively documents the new parameters (cu_seqlens_cpu, safe_gate, lower_bound, disable_recompute) including their interactions and constraints.

Also applies to: 403-417


495-513: LGTM: ChunkKDAFunction.apply properly wired.

All new parameters (cu_seqlens_cpu, safe_gate, lower_bound, disable_recompute) are correctly passed to the autograd function in the proper order.


490-490: Assertion correctly enforces kernel's head dimension limit.

The assertion assert k.shape[-1] <= 256 is properly aligned with the underlying kernel constraints. The kernel chunk_gated_delta_rule_fwd_kernel_h_blockdim64 in fla/ops/common/chunk_delta_h.py enforces the same limit (lines 488, 536), with internal structure supporting up to 256 dimensions via four 64-element blocks.

Comment thread benchmarks/ops/benchmark_kda.py
Manually cast g to fp32 when loading it, decoupling the kernels from the requirement that g be stored in fp32.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
fla/ops/kda/chunk_intra.py (1)

690-696: Ensure the USE_GATHER=False fallback path is tested.

The scalar load fallback at lines 694-696 is correctly implemented and activates when IS_GATHER_SUPPORTED=False. As noted in a past review, the test suite doesn't explicitly exercise this fallback. On hardware with gather support, this branch never executes. Consider adding a test that parameterizes IS_GATHER_SUPPORTED to cover both paths.

🧹 Nitpick comments (3)
fla/ops/kda/chunk_bwd.py (1)

220-220: Barrier retention is important for correctness.

The explicit "DO NOT REMOVE" comment on tl.debug_barrier() suggests this is required for correct synchronization behavior due to Triton compiler sensitivities. Consider adding a brief comment explaining why removal breaks correctness (e.g., race condition on shared state, compiler optimization issue).

fla/ops/common/chunk_delta_h.py (2)

145-162: Consider explicitly casting b_g_last to FP32 for consistency.

Line 149 explicitly casts b_g to tl.float32, but line 147 loads b_g_last without an explicit cast. The subsequent subtraction b_g_last - b_g (lines 151, 154) will trigger implicit casting, but for consistency and clarity, consider explicitly casting b_g_last:

b_g_last = tl.load(g + bos * H + last_idx * H + i_h).to(tl.float32)

This ensures all gate values uniformly use FP32 for numerical operations and makes the dtype handling more explicit.


335-344: Consider explicitly casting bg_last to FP32 for consistency.

Similar to the forward kernel, line 338 explicitly casts b_g to tl.float32, but line 336 loads bg_last without an explicit cast. For consistency with the explicit FP32 casting pattern and to avoid implicit type conversions in bg_last - b_g (line 387/389), consider:

bg_last = tl.load(g + (bos + last_idx) * H + i_h).to(tl.float32)
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2e003e3 and 86867ca.

📒 Files selected for processing (5)
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/gla/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_intra.py
  • fla/ops/kda/wy_fast.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.

Applied to files:

  • fla/ops/kda/chunk_intra.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (17)
fla/ops/gla/chunk.py (1)

352-352: LGTM! Explicit float32 casting ensures numerical precision.

Casting b_g to tl.float32 before exponential operations ensures consistent precision regardless of input dtype, aligning with the broader PR objective of supporting FP16 gates while preserving numerical behavior.

fla/ops/kda/chunk_bwd.py (1)

191-194: LGTM! Explicit float32 casting for gate values.

Both b_g and b_gn are now explicitly cast to tl.float32, ensuring numerical stability for subsequent exp2 operations regardless of the input tensor's dtype.

fla/ops/kda/wy_fast.py (4)

22-22: Verify the tf32 precision option aligns with numerical requirements.

The autotuning now includes "tf32" in addition to "tf32x3" and "ieee". Note that tf32 provides lower precision than tf32x3 (which uses 3 rounds of TF32). Given that this PR aims to maintain numerical stability, confirm that the autotuner selecting tf32 over tf32x3 won't degrade precision below acceptable thresholds for the use cases.


28-28: Kernel renamed to clarify KDA-specific usage.

The rename from recompute_w_u_fwd_kernel to recompute_w_u_fwd_kda_kernel improves clarity. Ensure any external consumers (if any) are updated accordingly.


82-82: LGTM! Float32 casting for gate computations.

Explicit .to(tl.float32) casts on b_gk and b_gn ensure numerical stability in exponential operations, consistent with the PR's objective of supporting FP16 gates.

Also applies to: 94-94


221-221: Extended return type enables recompute optimization.

The function now returns a 4-tuple (w, u, qg, kg) where qg and kg are conditionally computed based on disable_recompute. This allows the caller to cache precomputed gated values and avoid redundant computation in the backward pass.

Also applies to: 233-234, 256-256
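The caller-side pattern can be sketched as follows (names and the gating arithmetic are illustrative placeholders, not the actual fla API):

```python
import numpy as np

def compute_w_u(q, k, g):
    # Placeholder for the real kernel work.
    return q + k, q - k

def fwd(q, k, g, disable_recompute=False):
    w, u = compute_w_u(q, k, g)
    if disable_recompute:
        # Cache the gated projections so the backward pass can reuse them.
        qg, kg = q * np.exp(g), k * np.exp(g)
    else:
        qg = kg = None  # backward pass recomputes them on the fly
    return w, u, qg, kg
```

This trades extra forward-pass memory for avoided recomputation in backward, which is why the flag is exposed to the caller.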

fla/ops/kda/chunk_intra.py (8)

14-16: Precision change from tf32x3 to tf32.

Changing SOLVE_TRIL_DOT_PRECISION from 'tf32x3' to 'tf32' may improve performance but can reduce numerical precision. This aligns with a past review comment noting this tradeoff. Given the PR's test tolerance adjustments, this appears intentional.


244-277: Forward substitution correctly bypassed for safe-gate path.

When USE_SAFE_GATE is true, the iterative forward substitution loop is skipped. This is correct because the new chunk_kda_fwd_kernel_intra_sub_chunk kernel (lines 636-737) handles diagonal block inversion internally, so Akkd already contains the inverted blocks when the safe-gate path is used.


464-504: Safe-gate backward path uses block-wise computation for better Tensor Core utilization.

The SAFE_GATE branch replaces per-element loops with block-wise matrix operations (tl.dot), which better utilizes Tensor Cores. The median-shift approach (BC//2) for numerical stability is consistent with the forward path.
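The median-shift factorization itself can be demonstrated in a few lines of NumPy (a sketch of the idea, not the Triton code): instead of forming `exp(g_i) * exp(-g_j)` directly, both factors are shifted by the gate value at the sub-chunk's median index, which bounds the magnitude of each exponential while leaving their product unchanged.

```python
import numpy as np

# Log-domain gates within one sub-chunk (monotone non-increasing cumsum).
g = np.cumsum(np.full(16, -4.0))     # g runs from -4 down to -64

m = len(g) // 2                      # median index, analogous to the kernel's BC//2
lhs = np.exp(g - g[m])               # exp(g_i - g_m), magnitudes bounded by the shift
rhs = np.exp(g[m] - g)               # exp(g_m - g_j)
A = np.tril(np.outer(lhs, rhs))      # A[i, j] = exp(g_i - g_j) for i >= j

# Reference: the directly computed causal decay matrix.
ref = np.tril(np.exp(g[:, None] - g[None, :]))
```

Without the shift, the `exp(-g_j)` factor would reach `exp(64)` here, far outside the range representable in half precision; after shifting, the largest exponent is roughly halved, which is what makes the block-wise `tl.dot` formulation viable.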


560-609: Symmetric safe-gate handling in the dk backward computation.

The second SAFE_GATE block mirrors the first, computing b_dkt using block-wise operations instead of element-wise loops. The consistent use of min(BC//2, T - i_ti - 1) for the median index ensures numerical stability.


729-736: Forward substitution in sub-chunk kernel produces inverted diagonal blocks.

The loop correctly computes the inverse of the lower-triangular diagonal block in-place. Adding the identity matrix (m_I) at the end and overwriting Akk ensures the output is the full inverse. This enables the inter-solve kernel to skip its forward substitution when USE_SAFE_GATE=True.
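The forward-substitution inversion pattern can be sketched in NumPy (illustrative only: the kernel's sign conventions and memory layout differ, but the row-by-row recurrence and the final identity add are the same idea).

```python
import numpy as np

rng = np.random.default_rng(1)
n = 8
# Strictly lower-triangular diagonal block, as produced before inversion.
A = np.tril(rng.standard_normal((n, n)), k=-1)

# Row-by-row forward substitution: writing (I + A)^{-1} = I + Y with Y
# strictly lower-triangular gives Y = -A - A @ Y, where row i of Y depends
# only on rows j < i, so a single in-order sweep suffices.
Ainv = np.zeros_like(A)
for i in range(n):
    Ainv[i] = -A[i] - A[i] @ Ainv   # row i of the inverse, minus identity

# Adding the identity at the end mirrors the kernel's "+ m_I" step.
Ainv += np.eye(n)
ref = np.linalg.inv(np.eye(n) + A)
```

After the loop, `Ainv` holds the full inverse of the diagonal block, which is exactly what lets the inter-solve kernel skip its own forward substitution.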


767-801: Conditional kernel dispatch based on safe_gate.

The safe-gate path uses the new chunk_kda_fwd_kernel_intra_sub_chunk kernel which computes diagonal blocks with improved numerical stability via median shift, while the non-safe-gate path continues using the token-parallel kernel. This branching correctly isolates the new behavior.


823-833: Extended return signature propagates precomputed gated values.

The function now returns (w, u, qg, kg, Aqk, Akk) where qg and kg are conditionally computed when disable_recompute=True. This allows the forward pass to cache these values for the backward pass, avoiding redundant computation.


892-898: Safe-gate and gather support flags correctly propagated to backward kernel.

Both SAFE_GATE=safe_gate and USE_GATHER=IS_GATHER_SUPPORTED are passed to chunk_kda_bwd_kernel_intra, ensuring the backward path uses matching logic to the forward path.

fla/ops/common/chunk_delta_h.py (3)

164-192: LGTM! Consistent FP32 casting for gk loads.

All gk tensor loads across the four 64-element key blocks are explicitly cast to tl.float32. This ensures numerical stability in subsequent exp/exp2 operations and enables FP16 gate support as intended.


355-382: LGTM! Backward kernel gk casting is consistent.

All gk loads in the backward kernel are explicitly cast to tl.float32, mirroring the forward kernel pattern. This maintains numerical consistency and enables FP16 gate support throughout the forward and backward passes.


149-149: Excellent approach to enable FP16 gate support.

The explicit .to(tl.float32) casts after loading gate tensors effectively decouple the input tensor dtype from the computation dtype. This allows FP16 gates to be used for memory efficiency while maintaining FP32 precision for the critical exponential operations, preserving numerical stability. The pattern is consistently applied across both forward and backward kernels for all gk loads.

Also applies to: 166-166, 173-173, 180-180, 187-187, 338-338, 357-357, 365-365, 373-373, 381-381

@zhiyuan1i zhiyuan1i merged commit d1097c6 into main Dec 30, 2025
1 of 2 checks passed
@zhiyuan1i zhiyuan1i deleted the msh/lowerbound-kda branch December 30, 2025 13:57
@wfloveiu
Copy link
Copy Markdown

Good optimization. One question about chunk_intra.py: in the `if not USE_SAFE_GATE:` branch, when safe_gate=True the code skips forward substitution on the diagonal blocks and directly computes the off-diagonal sub-blocks. So where is the inverse of the diagonal blocks computed?
