
[KDA][GDN] Support transpose_state_layout for [V,K] state memory layout #776

Merged
zhiyuan1i merged 1 commit into main from lzy/transpose-kv on Mar 10, 2026
Conversation

zhiyuan1i (Collaborator) commented Mar 10, 2026

Add transpose_state_layout parameter to chunk, fused_recurrent, and context parallel paths for both KDA and GDN. When enabled, all state tensors use [V,K] layout instead of [K,V] to improve memory access patterns.
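The motivation is layout-dependent stride behavior. A minimal pure-Python sketch (illustrative only; the real kernels use Triton block pointers, and these helper names are not part of the library) of how row-major [K,V] vs [V,K] layouts change which logical axis is contiguous:

```python
def offset_kv(k, v, K, V):
    # [K, V] row-major layout: strides (V, 1); v is the contiguous axis
    return k * V + v

def offset_vk(k, v, K, V):
    # [V, K] row-major layout: strides (K, 1); k is the contiguous axis
    return v * K + k

K, V = 4, 8
# Stepping over k at fixed v is a stride-V hop in [K, V] but a unit-stride
# hop in [V, K] -- the access-pattern difference the flag exposes.
stride_over_k_in_kv = offset_kv(1, 0, K, V) - offset_kv(0, 0, K, V)
stride_over_k_in_vk = offset_vk(1, 0, K, V) - offset_vk(0, 0, K, V)
assert (stride_over_k_in_kv, stride_over_k_in_vk) == (V, 1)
```

Which layout is faster depends on whether a kernel's inner loops walk the K axis or the V axis of the state.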

Summary by CodeRabbit

  • New Features

    • Added an optional transpose_state_layout flag (default False) across multiple ops and fused/chunk workflows to select an alternate state tensor layout for forward/backward paths and final-state outputs. Preserves existing behavior when disabled and is threaded through kernel invocations and public APIs.
  • Tests

    • Added unit and multi-GPU context-parallel tests covering outputs, gradients, and final-state layouts with transpose_state_layout enabled.

coderabbitai bot (Contributor) commented Mar 10, 2026

Walkthrough

Adds an optional transpose_state_layout: bool flag threaded through Python APIs, Intracard/CP backends, Triton kernels (TRANSPOSE_STATE constexpr / autotune keys), and tests to enable an alternate transposed internal state layout; default behavior unchanged when False.

Changes

Cohort / File(s) Summary
Backend & Intracard
fla/ops/common/backends/intracard.py, fla/ops/common/intracard_cp.py
Added transpose_state_layout parameter to IntraCard CP APIs and forwarded it into intracard_fwd_h/merge paths; conditional allocations and shape handling for transposed vs standard state layouts; propagated flag to kernel launches.
Common chunk kernels & o-path
fla/ops/common/chunk_delta_h.py, fla/ops/common/chunk_o.py
Introduced TRANSPOSE_STATE constexpr and transpose_state_layout wiring: kernel signatures, autotune keys, pointer/layout branching, tl.trans usage, and grid/meta dispatch updated to pass the flag.
CP merged kernels
fla/ops/cp/chunk_delta_h.py
Extended merged kernels and pre/post-processors to accept TRANSPOSE_STATE with layout-dependent buffer shapes, pointer arithmetic, and recurrence/load/store paths.
Gated-delta & fused recurrent
fla/ops/gated_delta_rule/chunk.py, fla/ops/gated_delta_rule/fused_recurrent.py
Threaded transpose_state_layout through forward/backward, autograd ctx, and fused kernels; final-state allocations and per-step logic branch on layout flag.
GLA & KDA operators
fla/ops/gla/chunk.py, fla/ops/kda/chunk.py, fla/ops/kda/chunk_fwd.py, fla/ops/kda/chunk_bwd.py, fla/ops/kda/fused_recurrent.py
Exposed transpose_state_layout/TRANSPOSE_STATE across public APIs and kernels; updated autotune keys, pointer construction, and both fwd/bwd paths to respect transposed layout.
Autotune / kernel keys
fla/ops/*/*_*.py (various kernel modules)
Autotune keys updated to include TRANSPOSE_STATE for affected kernels; kernel signatures extended with TRANSPOSE_STATE: tl.constexpr.
IntraCard CP backend facade
fla/ops/common/backends/intracard.py
Updated IntraCardCPBackend methods to accept transpose_state_layout and forward it into intracard_fwd_h verifier / call sites.
Tests (CP & ops)
tests/context_parallel/test_cp_gdn.py, tests/context_parallel/test_cp_kda.py, tests/ops/test_gated_delta.py, tests/ops/test_kda.py
Added/updated tests to exercise transpose_state_layout in CP (2/4) and operator suites; new parameterized tests validate forward/backward equivalence and final-state transpositions between layouts.

Sequence Diagram(s)

sequenceDiagram
  participant User as Operator call
  participant PyAPI as Python API / Wrapper
  participant Backend as Intracard/CP Backend
  participant Kernel as Triton Kernel
  participant Device as GPU Memory

  User->>PyAPI: call op(..., transpose_state_layout=flag)
  PyAPI->>Backend: forward flag, prepare/reshape tensors (K,V) or (V,K)
  Backend->>Kernel: launch kernel with TRANSPOSE_STATE=flag
  Kernel->>Device: read/write state buffers (layout depends on flag)
  Kernel-->>Backend: return outputs + final_state (shaped per flag)
  Backend-->>PyAPI: return outputs (final_state in chosen layout)
  PyAPI-->>User: return results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • yzhangcs

Poem

🐰 A hop, a flip — V greets K,

Bits rearranged in merry play.
Default stays for those who rest,
But transposed states now pass the test.
Kernels hum; the rabbit’s pleased today.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — coverage is 31.58%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Description Check ✅ — check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ — the title clearly and concisely describes the main change: adding transpose_state_layout support for the [V,K] state memory layout to KDA and GDN implementations.


gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization by allowing state tensors in KDA and GDN operations to adopt a transposed memory layout. This change aims to enhance memory access patterns, which can be crucial for performance-sensitive applications. The new transpose_state_layout parameter provides flexibility for users to configure state tensor storage in chunk, fused recurrent, and context parallel implementations, enabling better hardware utilization.

Highlights

  • New Feature: Transposed State Layout: Introduced a transpose_state_layout parameter across various operations to allow state tensors to be stored in a [V,K] memory layout instead of the default [K,V].
  • Improved Memory Access Patterns: The primary motivation for this change is to optimize memory access patterns, potentially leading to performance gains in specific scenarios for KDA and GDN models.
  • Broad Application: The transpose_state_layout parameter has been integrated into chunk, fused_recurrent, and context parallel paths for both KDA and GDN operations, ensuring consistent behavior across these components.


Changelog
  • fla/ops/common/backends/intracard.py
    • Added transpose_state_layout parameter to chunk_gated_delta_rule_fwd_h_verifier and chunk_gated_delta_rule_fwd_h functions.
    • Passed transpose_state_layout to the intracard_fwd_h call.
  • fla/ops/common/chunk_delta_h.py
    • Included TRANSPOSE_STATE in the autotune keys for forward and backward kernels.
    • Modified state tensor initialization and block pointer creation to conditionally use [V, K] layout based on TRANSPOSE_STATE.
    • Adjusted matrix multiplication and scaling operations to correctly handle transposed state tensors.
  • fla/ops/common/chunk_o.py
    • Included TRANSPOSE_STATE in the autotune keys and as a constexpr parameter for forward and backward kernels.
    • Modified block pointer creation and dot product operations to support [V, K] state layout.
    • Added transpose_state_layout parameter to chunk_fwd_o and chunk_bwd_dqkwg functions.
  • fla/ops/common/intracard_cp.py
    • Added transpose_state_layout parameter to _raw_chunk_gated_delta_rule_fwd_h and intracard_merge functions.
    • Modified state tensor initialization and creation to respect the transpose_state_layout flag.
    • Passed TRANSPOSE_STATE to relevant kernel calls.
  • fla/ops/cp/chunk_delta_h.py
    • Updated comments for h and h0 parameters in merge_fwd_bwd_kernel to reflect [V, K] layout support.
    • Added TRANSPOSE_STATE constexpr parameter to merge_fwd_bwd_kernel and updated its docstring.
    • Adjusted b_h initialization, block pointer creation, and matrix multiplication logic to handle transposed states.
    • Added transpose_state_layout parameter to chunk_gated_delta_rule_fwd_h_pre_process and chunk_gated_delta_rule_bwd_dhu_pre_process.
  • fla/ops/gated_delta_rule/chunk.py
    • Added transpose_state_layout parameter to chunk_gated_delta_rule_fwd and chunk_gated_delta_rule_bwd functions.
    • Propagated transpose_state_layout through various internal function calls and stored it in the ctx object for backward pass.
    • Added transpose_state_layout to the main chunk_gated_delta_rule function.
  • fla/ops/gated_delta_rule/fused_recurrent.py
    • Added TRANSPOSE_STATE constexpr parameter to fused_recurrent_gated_delta_rule_fwd_kernel.
    • Modified mask_h calculation, b_h initialization, p_h0 and p_ht calculations, and dot product operations to support transposed states.
    • Added transpose_state_layout parameter to fused_recurrent_gated_delta_rule_fwd and fused_recurrent_gated_delta_rule.
    • Updated the docstring for fused_recurrent_gated_delta_rule to describe the new parameter.
  • fla/ops/gla/chunk.py
    • Included TRANSPOSE_STATE in the autotune key for chunk_gla_fwd_A_kernel_intra_sub_intra_merge and as a constexpr for chunk_gla_fwd_kernel_o.
    • Modified block pointer creation and dot product operations in chunk_gla_fwd_kernel_o to handle transposed states.
    • Added transpose_state_layout parameter to chunk_gla_fwd_o_gk.
  • fla/ops/kda/chunk.py
    • Added transpose_state_layout parameter to _ChunkKDA.forward and _ChunkKDA.backward methods, storing it in ctx.
    • Propagated transpose_state_layout to chunk_kda_fwd and chunk_kda_bwd calls.
    • Added transpose_state_layout to the main chunk_kda function.
  • fla/ops/kda/chunk_bwd.py
    • Included TRANSPOSE_STATE in the autotune key for chunk_kda_bwd_kernel_dAv and as a constexpr for chunk_kda_bwd_kernel_wy_dqkg_fused.
    • Modified block pointer creation for p_h and p_dh to support [V, K] layout.
    • Added transpose_state_layout parameter to chunk_kda_bwd_wy_dqkg_fused and chunk_kda_bwd.
  • fla/ops/kda/chunk_fwd.py
    • Added transpose_state_layout parameter to chunk_kda_fwd.
    • Propagated transpose_state_layout to internal calls like chunk_gated_delta_rule_fwd_h_pre_process, chunk_gated_delta_rule_fwd_h, and chunk_gla_fwd_o_gk.
  • fla/ops/kda/fused_recurrent.py
    • Added TRANSPOSE_STATE constexpr parameter to fused_recurrent_kda_fwd_kernel.
    • Modified mask_h calculation, b_h initialization, p_h0 and p_ht calculations, and dot product operations to support transposed states.
    • Added transpose_state_layout parameter to fused_recurrent_kda_fwd and fused_recurrent_kda.
    • Updated the docstring for fused_recurrent_kda to describe the new parameter.
  • tests/context_parallel/test_cp_gdn.py
    • Added transpose_state_layout parameter to run_cp_gdn_test_worker and run_cp_test_with_spawn.
    • Passed transpose_state_layout to chunk_gated_delta_rule in both reference and CP implementations.
    • Added new test cases test_cp2_transpose_state and test_cp4_transpose_state to validate the transposed state layout in context parallel GDN.
  • tests/context_parallel/test_cp_kda.py
    • Added transpose_state_layout parameter to run_cp_kda_test_worker and run_cp_test_with_spawn.
    • Passed transpose_state_layout to fused_recurrent_kda in both reference and CP implementations.
    • Added new test cases test_cp2_transpose_state and test_cp4_transpose_state to validate the transposed state layout in context parallel KDA.
  • tests/ops/test_gated_delta.py
    • Added new test cases test_chunk_transpose_state and test_fused_recurrent_transpose_state to verify the correctness of the transposed state layout for gated delta rule operations.
  • tests/ops/test_kda.py
    • Added new test cases test_fused_recurrent_transpose_state and test_chunk_transpose_state to verify the correctness of the transposed state layout for KDA operations.

gemini-code-assist bot left a comment

Code Review

This pull request introduces support for a transposed state memory layout ([V,K]) across various components for KDA and GDN to improve memory access patterns. The changes are extensive, consistently propagating the transpose_state_layout flag from high-level functions down to the Triton kernels. The logic for handling both memory layouts appears correct, and the PR includes new tests to verify this functionality.

My main feedback concerns the significant code duplication introduced in several Triton kernels. Many if/else blocks for the new TRANSPOSE_STATE flag repeat large chunks of code with only minor differences. While this might be necessary to some extent due to Triton's constraints, there are opportunities to refactor and reduce this duplication, which would greatly improve the code's readability and maintainability. I've added specific comments with suggestions on how to approach this refactoring.

Note: Security Review did not run due to the size of the PR.

Comment on lines +74 to +89
if TRANSPOSE_STATE:
    b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
else:
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)
Severity: medium

There's significant code duplication here for initializing the state tensors b_h1, b_h2, etc., based on TRANSPOSE_STATE. This can be refactored to improve readability and maintainability by defining the shape conditionally. Since TRANSPOSE_STATE is a constexpr, the compiler can optimize this, but the duplicated code makes it harder for developers to maintain.

Consider defining the shape dimensions conditionally and reusing them:

shape_dim0 = BV if TRANSPOSE_STATE else 64
shape_dim1 = 64 if TRANSPOSE_STATE else BV

b_h1 = tl.zeros([shape_dim0, shape_dim1], dtype=tl.float32)
if K > 64:
    b_h2 = tl.zeros([shape_dim0, shape_dim1], dtype=tl.float32)
# ... and so on

This pattern of duplication appears throughout the file and could be similarly refactored.
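A plain-Python analogue of this suggestion (illustrative only; the kernel itself would use tl.zeros with the TRANSPOSE_STATE constexpr, and make_states is a hypothetical name): selecting the shape once keeps the two layout branches from drifting apart.

```python
def make_states(K, BV, transpose_state):
    # Shape chosen once from the layout flag, mirroring the suggested
    # shape_dim0/shape_dim1 refactor; both layouts share the allocation code.
    rows, cols = (BV, 64) if transpose_state else (64, BV)
    states = [[[0.0] * cols for _ in range(rows)]]   # b_h1
    for threshold in (64, 128, 192):                 # b_h2..b_h4, gated on K
        if K > threshold:
            states.append([[0.0] * cols for _ in range(rows)])
    return states

# With K = 256 all four state blocks exist; with K = 64 only the first.
assert len(make_states(256, 32, transpose_state=True)) == 4
assert len(make_states(64, 32, transpose_state=False)) == 1
```

A fix to the allocation then touches one code path instead of two mirrored ones.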

Comment on lines 105 to 128
 if USE_INITIAL_STATE:
-    p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
+    if TRANSPOSE_STATE:
+        p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
+    else:
+        p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
     b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
     if K > 64:
-        p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
+        if TRANSPOSE_STATE:
+            p_h0_2 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0))
+        else:
+            p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
         b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
     if K > 128:
-        p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
+        if TRANSPOSE_STATE:
+            p_h0_3 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0))
+        else:
+            p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
         b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
     if K > 192:
-        p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
+        if TRANSPOSE_STATE:
+            p_h0_4 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0))
+        else:
+            p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
         b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
Severity: medium

This block for loading the initial state has a lot of duplicated code due to the TRANSPOSE_STATE flag. This pattern of duplication appears in several other places in this kernel (e.g., main recurrence loop, final state storing). While tl.constexpr helps the compiler, it makes the code harder to read and maintain. A bug fix in one path might be missed in the other.

Consider refactoring to reduce this duplication. For example, you could define the parameters for tl.make_block_ptr conditionally at the beginning of the if USE_INITIAL_STATE: block, and then reuse them. This would make the logic for different K values much cleaner.
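One way to picture the suggested consolidation (a sketch assuming row-major buffers and make_block_ptr-style (shape, strides, offsets) semantics; the helper names are hypothetical): compute the pointer parameters once per layout, and note that element (k, v) of the [K,V] block and element (v, k) of the [V,K] block address the same logical state entry through layout-consistent linear offsets.

```python
def block_ptr_params(K, V, i_v, BV, transpose_state):
    # Parameters mirror the two tl.make_block_ptr calls in the kernel:
    # [V, K] layout: shape (V, K), strides (K, 1), offsets (i_v*BV, 0)
    # [K, V] layout: shape (K, V), strides (V, 1), offsets (0, i_v*BV)
    if transpose_state:
        return dict(shape=(V, K), strides=(K, 1), offsets=(i_v * BV, 0))
    return dict(shape=(K, V), strides=(V, 1), offsets=(0, i_v * BV))

def linear_offset(params, row, col):
    r = params["offsets"][0] + row
    c = params["offsets"][1] + col
    return r * params["strides"][0] + c * params["strides"][1]

K, V, i_v, BV = 128, 64, 1, 32
p_kv = block_ptr_params(K, V, i_v, BV, transpose_state=False)
p_vk = block_ptr_params(K, V, i_v, BV, transpose_state=True)
for k in range(3):
    for v in range(3):
        # same logical state entry, addressed through either layout
        assert linear_offset(p_kv, k, v) == k * V + (i_v * BV + v)
        assert linear_offset(p_vk, v, k) == (i_v * BV + v) * K + k
```

Building the parameter tuples once per layout is exactly the kind of consolidation the review proposes; the per-K-chunk code then differs only in its base offset (0, 64, 128, 192).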

Comment on lines +380 to +395
if TRANSPOSE_STATE:
    b_dh1 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 64:
        b_dh2 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 128:
        b_dh3 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 192:
        b_dh4 = tl.zeros([BV, 64], dtype=tl.float32)
else:
    b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
Severity: medium

Similar to the forward kernel, there's duplicated logic for initializing b_dh1, b_dh2, etc. This reduces readability and increases the chance of introducing bugs if one path is modified but the other is not. Please consider refactoring to consolidate the common logic and handle the differences based on TRANSPOSE_STATE more concisely.

Comment on lines 589 to +607
     if HAS_H0:
         orig_seq_id = tl.load(h0_seq_ids + i_seq).to(tl.int32)
-        p_h0 = tl.make_block_ptr(
-            h0 + (orig_seq_id * H + i_h) * K * V,
-            (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0)
-        )
-        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+        if TRANSPOSE_STATE:
+            p_h0 = tl.make_block_ptr(
+                h0 + (orig_seq_id * H + i_h) * V * K,
+                (V, K), (K, 1), (i_v * BV, 0), (BV, BK), (1, 0)
+            )
+            b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+        else:
+            p_h0 = tl.make_block_ptr(
+                h0 + (orig_seq_id * H + i_h) * K * V,
+                (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0)
+            )
+            b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
     else:
-        b_h = tl.zeros([BK, BV], dtype=tl.float32)
+        if TRANSPOSE_STATE:
+            b_h = tl.zeros([BV, BK], dtype=tl.float32)
+        else:
+            b_h = tl.zeros([BK, BV], dtype=tl.float32)
Severity: medium

This block for initializing the state has duplicated code for handling TRANSPOSE_STATE. This pattern continues in the merge loop. To improve maintainability, consider refactoring to reduce this duplication. You could, for instance, set up shape and offset-related variables conditionally at the beginning of the if INTRACARD_MODE: block and reuse them.

coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
fla/ops/cp/chunk_delta_h.py (1)

953-992: ⚠️ Potential issue | 🔴 Critical

Offset gk before these backward-decay loads.

These tl.load(gk + last_idx * H * K + ...) expressions never incorporate bos or i_h, so any non-zero sequence offset or head index reads decays from the wrong slice. In CP KDA backward this corrupts the preprocessed dht / initial_state for every head except the first one.

🩹 Proposed fix
     q += ((bos * H + i_h) * K).to(tl.int64)
     k += ((bos * H + i_h) * K).to(tl.int64)
     w += ((bos * H + i_h) * K).to(tl.int64)
+    if USE_GK:
+        gk += ((bos * H + i_h) * K).to(tl.int64)
     dhm += i_h * K * (V + K)
     stride_k = H * K
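The offset arithmetic behind this fix can be checked in isolation (a sketch assuming gk is a [T, H, K] row-major buffer, as the surrounding pointer math implies; the variable names are illustrative):

```python
# With gk stored as [T, H, K] row-major, element (t, h, k) lives at
# (t * H + h) * K + k. Pre-offsetting the base pointer by
# (bos * H + i_h) * K makes a later `base + last_idx * H * K + k`
# land on ((bos + last_idx) * H + i_h) * K + k -- the right slice.
H, K = 4, 128
bos, i_h, last_idx, k = 10, 2, 5, 7

expected = ((bos + last_idx) * H + i_h) * K + k   # element (bos+last_idx, i_h, k)
base = (bos * H + i_h) * K                        # the offset the fix adds
got = base + last_idx * H * K + k                 # address the kernel then loads
assert got == expected
```

Without the base offset, the load resolves to (last_idx * H) * K + k, i.e. sequence offset and head index are both dropped, which matches the corruption the comment describes for every head but the first.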
🧹 Nitpick comments (1)
tests/ops/test_gated_delta.py (1)

236-296: Consider adding gradient verification for fused recurrent transpose test.

The test_fused_recurrent_transpose_state only verifies forward outputs. The chunk version includes gradient checks. Consider adding backward pass verification for completeness, especially since fused_recurrent_gated_delta_rule supports gradients.

💡 Suggested enhancement
 def test_fused_recurrent_transpose_state(
     ...
 ):
     ...
-    q, k, v, beta, g, h0_kv, h0_vk = map(lambda x: x.to(device), (q, k, v, beta, g, h0_kv, h0_vk))
+    q, k, v, beta, g, h0_kv, h0_vk = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, beta, g, h0_kv, h0_vk))

     ref, ref_ht = fused_recurrent_gated_delta_rule(...)
     tri, tri_ht = fused_recurrent_gated_delta_rule(...)
+    
+    do = torch.randn_like(ref)
+    dht_vk = torch.randn(B, HV, D, D, dtype=torch.float32, device=device)
+    dht_kv = dht_vk.transpose(-1, -2).contiguous()
+    
+    ((tri * do).sum() + (tri_ht * dht_vk).sum()).backward(retain_graph=True)
+    tri_dq, tri_dk, tri_dv = q.grad, k.grad, v.grad
+    q.grad = k.grad = v.grad = None
+    
+    ((ref * do).sum() + (ref_ht * dht_kv).sum()).backward(retain_graph=True)
+    ref_dq, ref_dk, ref_dv = q.grad, k.grad, v.grad
+    
+    assert_close('dq', ref_dq, tri_dq, 1e-4)
+    # ... additional gradient checks
+
     assert_close('o', ref, tri, 1e-4)
     assert_close('ht', ref_ht, tri_ht.transpose(-1, -2), 1e-4)
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Line 313: The wrapper currently accepts transpose_state_layout: bool and then
only checks initial_state.shape[0], which lets a mismatched layout ([N,H,K,V] vs
expected [N,H,V,K]) slip through; update the wrapper to validate
initial_state.ndim and the ordering of dimensions when transpose_state_layout is
True by asserting (or raising a clear ValueError) that initial_state has four
dims and that its shape matches the expected [N, H, V, K] layout (or the
non-transposed [N, H, K, V] when False), referencing the transpose_state_layout
flag and initial_state parameter so misuse fails fast rather than silently
producing wrong outputs/gradients.
- Line 238: ChunkGatedDeltaRuleFunction.forward was extended to accept
transpose_state_layout (making 13 inputs) but backward still returns only 12
gradients; update ChunkGatedDeltaRuleFunction.backward to include an extra
gradient placeholder (e.g., None) corresponding to the transpose_state_layout
input so the returned tuple length matches the forward inputs. Locate the
backward implementation and add the additional None in the correct position in
the returned gradients tuple (matching the transpose_state_layout parameter) to
avoid PyTorch's incorrect-gradient-count error.

In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Line 171: When transpose_state_layout is True in
fused_recurrent_gated_delta_rule_fwd (and the analogous call around lines
181-185), validate that initial_state has the expected tail shape/layout ([N,
HV, V, K]) before proceeding; detect mismatches where initial_state is in [N,
HV, K, V] (or any other incompatible tail shape) and raise a clear ValueError
indicating the expected vs actual tail dims. Implement this check immediately
after reading transpose_state_layout and before any pointer math or kernel calls
so the kernel never silently misreads a mismatched cache; reference the
transpose_state_layout parameter and the initial_state tensor in your error
message for clarity.

In `@fla/ops/kda/chunk.py`:
- Line 163: chunk_kda currently only validates sequence count and dtype but not
the cache memory layout, so if transpose_state_layout is toggled it can silently
reinterpret memory in chunk_kda_fwd / chunk_kda_bwd; update chunk_kda() to
detect and reject mismatched cache layouts by checking the cache tensor's
shape/order against the transpose_state_layout flag (e.g., if
transpose_state_layout is False expect [N, S, H, K, V] layout or if True expect
[N, H, K, V] etc.), and raise a clear error when the actual layout doesn't match
the flag before calling chunk_kda_fwd / chunk_kda_bwd; reference the
transpose_state_layout parameter and the cache input used by chunk_kda(), and
add the guard early in chunk_kda() so downstream fwd/bwd kernels never
reinterpret memory incorrectly.

In `@fla/ops/kda/fused_recurrent.py`:
- Line 239: The code accepts transpose_state_layout but never validates
initial_state layout, so a stale [N, HV, K, V] buffer can be silently
reinterpreted as [N, HV, V, K]; add an explicit fast-fail validation in both the
inplace and regular paths (where transpose_state_layout is used) that checks
initial_state's dimensions/order against the expected final layout when
transpose_state_layout is True (expected axes [N, HV, V, K]) and raise a clear
ValueError if they mismatch; alternatively, if safe, perform an explicit
transpose/reorder of initial_state into the required layout before caching/using
it—apply this check/reorder for the code paths around the transpose_state_layout
flag and any functions that consume initial_state so the kernel never receives a
stale layout.
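A hedged sketch of the fast-fail guard these comments describe (the function name and call pattern are hypothetical, not the repo's API; only the [N, H, V, K] vs [N, H, K, V] expectation comes from the review):

```python
def validate_initial_state(shape, H, K, V, transpose_state_layout):
    # Expected tail dims follow the layout flag: [N, H, V, K] when
    # transposed, [N, H, K, V] otherwise.
    expected_tail = (H, V, K) if transpose_state_layout else (H, K, V)
    if len(shape) != 4 or tuple(shape[1:]) != expected_tail:
        raise ValueError(
            f"initial_state expected [N, {expected_tail[0]}, "
            f"{expected_tail[1]}, {expected_tail[2]}], got {list(shape)}"
        )

# A cache in the matching layout passes...
validate_initial_state((2, 8, 64, 128), H=8, K=128, V=64, transpose_state_layout=True)
# ...while a stale [N, H, K, V] cache is rejected instead of being misread.
try:
    validate_initial_state((2, 8, 128, 64), H=8, K=128, V=64, transpose_state_layout=True)
    raise AssertionError("mismatched layout was not rejected")
except ValueError:
    pass
```

Note the check only catches mismatches when K != V; for square states (K == V) the layouts are shape-identical, so the flag must still be threaded consistently by the caller.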


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: d42dc09b-bf3d-46e0-946c-920a22ec96c9

📥 Commits

Reviewing files that changed from the base of the PR and between 034eda3 and 1116d52.

📒 Files selected for processing (16)
  • fla/ops/common/backends/intracard.py
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/common/chunk_o.py
  • fla/ops/common/intracard_cp.py
  • fla/ops/cp/chunk_delta_h.py
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/fused_recurrent.py
  • fla/ops/gla/chunk.py
  • fla/ops/kda/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_fwd.py
  • fla/ops/kda/fused_recurrent.py
  • tests/context_parallel/test_cp_gdn.py
  • tests/context_parallel/test_cp_kda.py
  • tests/ops/test_gated_delta.py
  • tests/ops/test_kda.py

Comment thread fla/ops/gated_delta_rule/chunk.py
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens_cpu: torch.LongTensor | None = None,
cp_context: FLACPContext | None = None,
transpose_state_layout: bool = False,

⚠️ Potential issue | 🟠 Major

Validate initial_state for the selected layout.

This flag changes the expected cache shape to [N, H, V, K], but the wrapper still only checks initial_state.shape[0]. An old [N, H, K, V] cache reused with transpose_state_layout=True will be silently misread by the Triton kernels, returning wrong outputs/gradients instead of failing fast.

🩹 Proposed fix
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                 f"Please flatten variable-length inputs before processing.",
             )
         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
             raise ValueError(
                 f"The number of initial states is expected to be equal to the number of input sequences, "
                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
             )
+    if initial_state is not None:
+        expected_tail = (
+            (q.shape[2], v.shape[-1], k.shape[-1])
+            if transpose_state_layout
+            else (q.shape[2], k.shape[-1], v.shape[-1])
+        )
+        if initial_state.shape[1:] != expected_tail:
+            raise ValueError(
+                f"`initial_state` must have shape [N, {expected_tail[0]}, {expected_tail[1]}, {expected_tail[2]}], "
+                f"got {tuple(initial_state.shape)}."
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk.py` at line 313, The wrapper currently accepts
transpose_state_layout: bool and then only checks initial_state.shape[0], which
lets a mismatched layout ([N,H,K,V] vs expected [N,H,V,K]) slip through; update
the wrapper to validate initial_state.ndim and the ordering of dimensions when
transpose_state_layout is True by asserting (or raising a clear ValueError) that
initial_state has four dims and that its shape matches the expected [N, H, V, K]
layout (or the non-transposed [N, H, K, V] when False), referencing the
transpose_state_layout flag and initial_state parameter so misuse fails fast
rather than silently producing wrong outputs/gradients.

output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
transpose_state_layout: bool = False,

⚠️ Potential issue | 🟠 Major

Validate initial_state when transpose mode is enabled.

fused_recurrent_gated_delta_rule_fwd now flips the state layout to [N, HV, V, K], but it still accepts any initial_state tail shape. Because the kernel switches pointer math between [K, V] and [V, K], a mismatched cache will be silently misread rather than rejected.

Also applies to: 181-185

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/fused_recurrent.py` at line 171, When
transpose_state_layout is True in fused_recurrent_gated_delta_rule_fwd (and the
analogous call around lines 181-185), validate that initial_state has the
expected tail shape/layout ([N, HV, V, K]) before proceeding; detect mismatches
where initial_state is in [N, HV, K, V] (or any other incompatible tail shape)
and raise a clear ValueError indicating the expected vs actual tail dims.
Implement this check immediately after reading transpose_state_layout and before
any pointer math or kernel calls so the kernel never silently misreads a
mismatched cache; reference the transpose_state_layout parameter and the
initial_state tensor in your error message for clarity.
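A minimal, framework-free sketch of the shape check this comment requests (the helper name and example sizes are hypothetical, not part of the fla API):

```python
def validate_initial_state_shape(shape, N, HV, K, V, transpose_state_layout):
    # expected tail is [V, K] in transpose mode, [K, V] otherwise
    expected = (N, HV, V, K) if transpose_state_layout else (N, HV, K, V)
    if tuple(shape) != expected:
        raise ValueError(
            f"initial_state must have shape {expected} when "
            f"transpose_state_layout={transpose_state_layout}, got {tuple(shape)}"
        )

# a [V, K] cache passes in transpose mode...
validate_initial_state_shape((2, 8, 64, 128), N=2, HV=8, K=128, V=64, transpose_state_layout=True)
# ...while a stale [K, V] cache is rejected instead of silently misread
try:
    validate_initial_state_shape((2, 8, 128, 64), N=2, HV=8, K=128, V=64, transpose_state_layout=True)
    raise AssertionError("stale cache accepted")
except ValueError:
    pass
```

Note that when K == V the two shapes coincide and this check alone cannot catch a layout mix-up, and a shape check says nothing about strides, so a contiguity check on the tensor is still needed on top.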

Comment thread fla/ops/kda/chunk.py
disable_recompute: bool = False,
return_intermediate_states: bool = False,
cp_context: FLACPContext = None,
transpose_state_layout: bool = False,

⚠️ Potential issue | 🟠 Major

Reject mismatched cache layouts at the KDA entrypoint.

transpose_state_layout is public now, but chunk_kda() still only validates the sequence count and dtype. Passing a [N, H, K, V] cache into transpose mode (or the inverse) will silently reinterpret memory in chunk_kda_fwd / chunk_kda_bwd and return incorrect activations and gradients.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/chunk.py` at line 163, chunk_kda currently only validates
sequence count and dtype but not the cache memory layout, so if
transpose_state_layout is toggled it can silently reinterpret memory in
chunk_kda_fwd / chunk_kda_bwd; update chunk_kda() to detect and reject
mismatched cache layouts by checking the cache tensor's shape/order against the
transpose_state_layout flag (e.g., if transpose_state_layout is False expect [N,
H, K, V], or if True expect [N, H, V, K]), and raise a clear error
when the actual layout doesn't match the flag before calling chunk_kda_fwd /
chunk_kda_bwd; reference the transpose_state_layout parameter and the cache
input used by chunk_kda(), and add the guard early in chunk_kda() so downstream
fwd/bwd kernels never reinterpret memory incorrectly.

use_gate_in_kernel: bool = False,
lower_bound: float | None = None,
out: torch.Tensor | None = None,
transpose_state_layout: bool = False,

⚠️ Potential issue | 🟠 Major

Fail fast on recurrent-state layout mismatches.

This path now allocates [N, HV, V, K] final states in transpose mode, but a provided initial_state is never checked against that layout. In both the inplace and regular paths, a stale [N, HV, K, V] cache will be silently reinterpreted by the kernel.

Also applies to: 259-262

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/fused_recurrent.py` at line 239, The code accepts
transpose_state_layout but never validates initial_state layout, so a stale [N,
HV, K, V] buffer can be silently reinterpreted as [N, HV, V, K]; add an explicit
fast-fail validation in both the inplace and regular paths (where
transpose_state_layout is used) that checks initial_state's dimensions/order
against the expected final layout when transpose_state_layout is True (expected
axes [N, HV, V, K]) and raise a clear ValueError if they mismatch;
alternatively, if safe, perform an explicit transpose/reorder of initial_state
into the required layout before caching/using it—apply this check/reorder for
the code paths around the transpose_state_layout flag and any functions that
consume initial_state so the kernel never receives a stale layout.

@zhiyuan1i force-pushed the lzy/transpose-kv branch 3 times, most recently from 23b84a2 to b574148 on March 10, 2026 at 04:21

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
fla/ops/common/chunk_o.py (1)

138-170: ⚠️ Potential issue | 🔴 Critical

Complete the transpose rollout for the dv path.

chunk_bwd_dqkwg is now layout-aware, but chunk_bwd_kernel_dv / chunk_bwd_dv still hard-code dh as (K, V) at Lines 379-381 and Lines 542-593. With transpose_state_layout=True, fla/ops/common/chunk_delta_h.py now produces dh in [V, K], so this branch will read the wrong strides and return incorrect dv.

Also applies to: 662-711

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/common/chunk_o.py` around lines 138 - 170, The dv-backward path still
assumes dh has layout (K, V) causing wrong strides when transpose_state_layout
is enabled; update the transpose rollout in the dv kernels so they honor the
layout flag. Specifically, in chunk_bwd_kernel_dv and the caller chunk_bwd_dv,
branch on the TRANSPOSE_STATE (or transpose_state_layout) constexpr and compute
dh strides/shape and the indexing into dh as [K, V] when false and [V, K] when
true (matching fla/ops/common/chunk_delta_h.py output); adjust any temporary
views/loads and the accumulation into dv accordingly so all stride calculations
and memory accesses use the correct dimension ordering under both layouts.
Ensure the same fix is applied to the other occurrence noted (the later
chunk_bwd_kernel/dv block) so both code paths mirror the layout-aware logic used
for dqkwg.
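The layout-aware indexing the dv path needs can be sketched in plain Python; in the Triton kernel this becomes a constexpr branch on TRANSPOSE_STATE, but the stride arithmetic is the same (the flat buffer and sizes below are illustrative):

```python
# A flat buffer stands in for dh; the two formulas address the same logical
# element under the [K, V] and [V, K] layouts respectively.
def dh_index(k, v, K, V, transpose_state):
    # [K, V] layout: row k, column v; [V, K] layout: row v, column k.
    return v * K + k if transpose_state else k * V + v

K, V = 4, 3
flat_kv = list(range(K * V))                      # stored as [K, V]
flat_vk = [flat_kv[k * V + v]                     # same data re-laid as [V, K]
           for v in range(V) for k in range(K)]
for k in range(K):
    for v in range(V):
        assert flat_kv[dh_index(k, v, K, V, False)] == flat_vk[dh_index(k, v, K, V, True)]
```

Reading `flat_vk` with the non-transposed formula would scramble elements whenever K != V differs from a square state, which is exactly the wrong-stride failure described above.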
fla/ops/kda/chunk.py (1)

140-141: ⚠️ Potential issue | 🔴 Critical

Add missing gradient placeholder in backward return tuple.

ChunkKDAFunction.forward() takes 20 inputs after ctx, but backward() at lines 140-141 returns only 19 gradients. PyTorch will fail with an arity mismatch error. Add one more None gradient to the return tuple:

-        return (dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), dA, dbias, None, dh0,
-                None, None, None, None, None, None, None, None, None, None)
+        return (dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), dA, dbias, None, dh0,
+                None, None, None, None, None, None, None, None, None, None, None)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/chunk.py` around lines 140 - 141, The backward() in
ChunkKDAFunction returns 19 gradients but forward() accepted 20 inputs after
ctx; update ChunkKDAFunction.backward() return tuple (the line returning
(dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), dA, dbias, None, dh0,
...)) to include one additional None so the returned tuple has 20 entries to
match forward()'s inputs (i.e., append another None in the return tuple).
♻️ Duplicate comments (5)
fla/ops/gated_delta_rule/fused_recurrent.py (1)

171-187: ⚠️ Potential issue | 🟠 Major

Validate initial_state for the selected layout.

This path now allocates [N, HV, V, K] in transpose mode, but a provided initial_state is still never checked. A stale cache — or a non-contiguous transpose view — will be silently reinterpreted by the kernel.

🩹 Proposed fix
     o = torch.empty_like(v)
+    if initial_state is not None:
+        expected_shape = (
+            (N, HV, V, K)
+            if transpose_state_layout
+            else (N, HV, K, V)
+        )
+        if (
+            initial_state.ndim != 4
+            or tuple(initial_state.shape) != expected_shape
+            or not initial_state.is_contiguous()
+        ):
+            raise ValueError(
+                f"`initial_state` must be contiguous and have shape {expected_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`, got shape "
+                f"{tuple(initial_state.shape)} with strides {tuple(initial_state.stride())}."
+            )
     if output_final_state:
         if transpose_state_layout:
             final_state = q.new_empty(N, HV, V, K, dtype=torch.float32)
         else:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 171 - 187, The code
allocates final_state with different memory layouts when transpose_state_layout
is True (q.new_empty(N, HV, V, K)) vs False (q.new_empty(N, HV, K, V)) but never
validates a provided initial_state; update the logic in the function that
handles initial_state/ final_state (look for initial_state,
transpose_state_layout, output_final_state, final_state) to check that when
initial_state is not None its shape and memory layout match the expected layout
for the chosen transpose_state_layout (e.g., exact shape [N, HV, V, K] for
transpose_mode or [N, HV, K, V] otherwise), raise a clear error if mismatched,
and ensure any view/transpose is made contiguous (or copy to the expected
layout) before passing to downstream kernels.
fla/ops/gated_delta_rule/chunk.py (2)

291-296: ⚠️ Potential issue | 🔴 Critical

Add the missing backward slot for transpose_state_layout.

Line 296 still returns 12 gradients for a forward() with 13 inputs after ctx, so PyTorch will fail the first backward pass with a gradient-arity error.

Verification script
#!/bin/bash
python - <<'PY'
import ast
from pathlib import Path

path = Path("fla/ops/gated_delta_rule/chunk.py")
tree = ast.parse(path.read_text())

cls = next(n for n in tree.body if isinstance(n, ast.ClassDef) and n.name == "ChunkGatedDeltaRuleFunction")
fwd = next(n for n in cls.body if isinstance(n, ast.FunctionDef) and n.name == "forward")
bwd = next(n for n in cls.body if isinstance(n, ast.FunctionDef) and n.name == "backward")
ret = next(n for n in ast.walk(bwd) if isinstance(n, ast.Return))

print("forward inputs:", len(fwd.args.args) - 1)
print("backward outputs:", len(ret.value.elts))
# Expect these counts to match.
PY
🩹 Proposed fix
-        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None
+        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None, None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk.py` around lines 291 - 296, The backward()
return tuple is missing a gradient slot for the forward() input
transpose_state_layout, causing a mismatch in arity; update
ChunkGatedDeltaRuleFunction.backward to include the missing gradient position
(most likely a None if no gradient is required) corresponding to
transpose_state_layout so the number of returned gradients equals the number of
forward inputs after ctx; locate the return in backward() and insert the
appropriate None (or computed grad) at the position matching
transpose_state_layout.

313-418: ⚠️ Potential issue | 🟠 Major

Validate the state tensor contract before dispatch.

transpose_state_layout flips the expected cache layout, but this wrapper still only checks sequence count. A stale cache — or a non-contiguous initial_state.transpose(-1, -2) view — will be silently misread by the Triton kernels.

🩹 Proposed fix
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                 f"Please flatten variable-length inputs before processing.",
             )
         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
             raise ValueError(
                 f"The number of initial states is expected to be equal to the number of input sequences, "
                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
             )
+    if initial_state is not None:
+        expected_n = len(cu_seqlens) - 1 if cu_seqlens is not None else q.shape[0]
+        expected_shape = (
+            (expected_n, q.shape[2], v.shape[-1], k.shape[-1])
+            if transpose_state_layout
+            else (expected_n, q.shape[2], k.shape[-1], v.shape[-1])
+        )
+        if (
+            initial_state.ndim != 4
+            or tuple(initial_state.shape) != expected_shape
+            or not initial_state.is_contiguous()
+        ):
+            raise ValueError(
+                f"`initial_state` must be contiguous and have shape {expected_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`, got shape "
+                f"{tuple(initial_state.shape)} with strides {tuple(initial_state.stride())}."
+            )
     if scale is None:
         scale = k.shape[-1] ** -0.5
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk.py` around lines 313 - 418, The wrapper must
validate the initial_state tensor's layout and shape before calling
ChunkGatedDeltaRuleFunction.apply: check initial_state is a 4-D tensor with
initial_state.shape[0] == (len(cu_seqlens)-1 if cu_seqlens is given else
q.shape[0]), and that its last two dims match K and V taking
transpose_state_layout into account (i.e., expect [..., K, V] when
transpose_state_layout=False and [..., V, K] when True); also ensure the tensor
is contiguous in the memory layout the Triton kernel expects (use
.is_contiguous() or verify strides) and if it isn’t, make a contiguous copy or
explicitly transpose+contiguous so the dispatched kernel never receives a
non-contiguous/transposed view. Apply these checks/normalization right before
calling ChunkGatedDeltaRuleFunction.apply.
fla/ops/kda/chunk.py (1)

163-338: ⚠️ Potential issue | 🟠 Major

Validate the state tensor contract before calling chunk_kda_fwd.

transpose_state_layout changes the cache layout, but chunk_kda() still only checks dtype and sequence count. A stale cache — or a non-contiguous initial_state.transpose(-1, -2) view — will be silently reinterpreted downstream.

🩹 Proposed fix
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                 f"Please flatten variable-length inputs before processing.",
@@
             )
     if initial_state is not None:
         assert initial_state.dtype == torch.float32, "initial_state must be in float32."
+        expected_n = len(cu_seqlens) - 1 if cu_seqlens is not None else q.shape[0]
+        expected_shape = (
+            (expected_n, q.shape[2], v.shape[-1], k.shape[-1])
+            if transpose_state_layout
+            else (expected_n, q.shape[2], k.shape[-1], v.shape[-1])
+        )
+        if (
+            initial_state.ndim != 4
+            or tuple(initial_state.shape) != expected_shape
+            or not initial_state.is_contiguous()
+        ):
+            raise ValueError(
+                f"`initial_state` must be contiguous and have shape {expected_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`, got shape "
+                f"{tuple(initial_state.shape)} with strides {tuple(initial_state.stride())}."
+            )
 
     A_log, dt_bias = None, None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/chunk.py` around lines 163 - 338, chunk_kda currently only checks
initial_state dtype and count but ignores layout changes caused by
transpose_state_layout, risking silent misinterpretation; before calling
ChunkKDAFunction.apply, validate initial_state more strictly: confirm dtype is
torch.float32, confirm initial_state.shape matches the expected layout ([N, H,
K, V] when transpose_state_layout=False or [N, H, V, K] when True), confirm the
first dimension equals len(cu_seqlens)-1 when cu_seqlens is provided, and ensure
the tensor is contiguous (or explicitly require/clone to contiguous) so a
transposed view cannot be passed through; perform these checks inside chunk_kda
(near the existing initial_state assertions) and raise descriptive
ValueError/assertion mentioning initial_state and transpose_state_layout if the
contract is violated.
fla/ops/kda/fused_recurrent.py (1)

239-262: ⚠️ Potential issue | 🟠 Major

Reject recurrent caches with the wrong tail layout.

This helper now produces [*, HV, V, K] caches in transpose mode, but a provided initial_state is still never validated. A stale cache — or a non-contiguous transpose view — will be silently reinterpreted in both the inplace and regular paths.

🩹 Proposed fix
     if out is None:
         out = torch.zeros_like(v)
     else:
         assert out.shape == v.shape
+    if initial_state is not None:
+        expected_tail = (
+            (HV, V, K)
+            if transpose_state_layout
+            else (HV, K, V)
+        )
+        if (
+            initial_state.ndim != 4
+            or tuple(initial_state.shape[1:]) != expected_tail
+            or not initial_state.is_contiguous()
+        ):
+            raise ValueError(
+                f"`initial_state` must be contiguous and end with {expected_tail} when "
+                f"`transpose_state_layout={transpose_state_layout}`, got shape "
+                f"{tuple(initial_state.shape)} with strides {tuple(initial_state.stride())}."
+            )
     if inplace_final_state:
         assert initial_state is not None
         final_state = initial_state
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/fused_recurrent.py` around lines 239 - 262, In
fused_recurrent.py, validate any provided initial_state against the expected
layout before reusing it: compute N and expected shape based on
transpose_state_layout (expected = (N, HV, V, K) when transpose_state_layout is
True, otherwise (N, HV, K, V)), then assert initial_state is not None implies
initial_state.shape == expected, initial_state.device == q.device and
initial_state.dtype == q.dtype (or float32 if q is cast), and that
initial_state.is_contiguous() (or use .contiguous() only after cloning) to avoid
silently reinterpreting a non-contiguous transpose view; raise a clear
ValueError/AssertionError if the checks fail, and in the non-inplace path
consider cloning/copying a validated initial_state into final_state instead of
reinterpreting it.
🧹 Nitpick comments (1)
tests/context_parallel/test_cp_gdn.py (1)

93-103: Use a rectangular K/V case in the transpose-state suite.

The harness still takes a single D, so every transpose-state run here has K == V. That masks exactly the [K, V] vs [V, K] mix-ups this flag is meant to catch; please split the helper into separate K/V sizes and add at least one K != V transpose case.

Also applies to: 275-295, 390-417

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/context_parallel/test_cp_gdn.py` around lines 93 - 103, The helper
run_cp_gdn_test_worker currently takes a single D so K==V for all
transpose-state tests; change its signature to accept separate K and V sizes
(e.g., add parameters K: int, V: int) and update its internal construction of
key/value tensors and any shapes that used D to use K and V respectively; update
all call sites (including the other two helper occurrences noted) to pass
distinct K and V where you want a rectangular case and add at least one test
invocation with K != V to the transpose-state suite so the [K,V] vs [V,K] mix-up
is exercised.
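The parametrization change being asked for can be sketched as follows; the case dicts and helper below are illustrative stand-ins for the pytest parameters, not the actual test harness:

```python
# Give the helper separate K and V so at least one case is rectangular.
CASES = [
    dict(B=2, H=2, K=64, V=64),
    dict(B=2, H=2, K=64, V=128),   # rectangular: layouts now distinguishable
]

def state_shape(case, transpose_state_layout):
    K, V = case["K"], case["V"]
    tail = (V, K) if transpose_state_layout else (K, V)
    return (case["B"], case["H"]) + tail

square, rect = CASES
assert state_shape(square, False) == state_shape(square, True)    # K == V masks mix-ups
assert state_shape(rect, False) != state_shape(rect, True)        # K != V exposes them
```

With K == V every transposed read still lands on a validly-shaped (if wrong) element, so only a rectangular case turns a [K, V] vs [V, K] mix-up into a visible shape or numeric mismatch.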
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/ops/test_gated_delta.py`:
- Around line 190-192: The test uses all-zero initial states h0_kv and h0_vk
which masks transpose_state_layout bugs; change the initialization of h0_kv (and
its transposed h0_vk) to a non-zero, deterministic tensor (e.g., seeded
torch.randn or a pattern like arange/ones scaled by indices) so that h0_kv != 0
and h0_vk = h0_kv.transpose(-1,-2).contiguous() and still pass both through the
existing map that sets .to(device).requires_grad_(True); update the lines
creating h0_kv/h0_vk and their subsequent mapping to ensure the forward path
detects incorrect transposed loads.
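Why an all-zero initial state masks this class of bug can be shown in a few lines of plain Python (sizes illustrative): reading a zero buffer with the wrong stride still yields zeros, so forward outputs match regardless of layout.

```python
K, V = 2, 3
h0_zero = [[0.0] * V for _ in range(K)]
h0_rand = [[float(k * V + v + 1) for v in range(V)] for k in range(K)]

def read_wrong_layout(h, K, V):
    # interpret a [K, V] buffer as if it were stored [V, K]
    flat = [x for row in h for x in row]
    return [[flat[v * K + k] for v in range(V)] for k in range(K)]

assert read_wrong_layout(h0_zero, K, V) == h0_zero    # bug invisible with zeros
assert read_wrong_layout(h0_rand, K, V) != h0_rand    # bug visible with random data
```

Hence the suggestion to seed h0_kv with non-zero values and derive h0_vk via `transpose(-1, -2).contiguous()`.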

In `@tests/ops/test_kda.py`:
- Around line 163-171: This transpose-only test is missing the Intel Alchemist
skip that test_fused_recurrent uses; add the same guard so the test calls
pytest.skip when running on Alchemist with D > 128. Locate the transpose test in
tests/ops/test_kda.py (same scope as the q,k,v,g,beta,h0_kv,h0_vk setup), import
or reuse the existing is_alchemist() helper and add a conditional like `if
is_alchemist() and D > 128: pytest.skip(...)` before creating tensors so the
test mirrors test_fused_recurrent's behavior.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: a9fac52e-3dca-4b1e-a740-a94e23338e3f

📥 Commits

Reviewing files that changed from the base of the PR and between 1116d52 and b574148.

📒 Files selected for processing (16)
  • fla/ops/common/backends/intracard.py
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/common/chunk_o.py
  • fla/ops/common/intracard_cp.py
  • fla/ops/cp/chunk_delta_h.py
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/fused_recurrent.py
  • fla/ops/gla/chunk.py
  • fla/ops/kda/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_fwd.py
  • fla/ops/kda/fused_recurrent.py
  • tests/context_parallel/test_cp_gdn.py
  • tests/context_parallel/test_cp_kda.py
  • tests/ops/test_gated_delta.py
  • tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/context_parallel/test_cp_kda.py

Comment thread tests/ops/test_gated_delta.py Outdated
Comment thread tests/ops/test_kda.py
Comment on lines +163 to +171
    torch.manual_seed(42)
    q = torch.rand(B, T, H, D, dtype=dtype)
    k = torch.rand(B, T, H, D, dtype=dtype)
    v = torch.rand(B, T, H, D, dtype=dtype)
    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=torch.float)) / gate_logit_normalizer
    beta = torch.randn(B, T, H, dtype=dtype).sigmoid()
    h0_kv = torch.randn(B, H, D, D, dtype=torch.float32)
    h0_vk = h0_kv.transpose(-1, -2).contiguous()
    q, k, v, g, beta, h0_kv, h0_vk = map(lambda x: x.to(device), (q, k, v, g, beta, h0_kv, h0_vk))

⚠️ Potential issue | 🟡 Minor

Mirror the existing Alchemist guard in this transpose test.

test_fused_recurrent already skips D > 128 on Intel Alchemist, but this new variant drops that guard, so the transpose-only test fails outright on a known-unsupported backend instead of being skipped there.

Suggested fix
 def test_fused_recurrent_transpose_state(
     B: int,
     T: int,
     H: int,
     D: int,
     scale: float,
     gate_logit_normalizer: float,
     dtype: torch.dtype,
 ):
     torch.manual_seed(42)
+    if IS_INTEL_ALCHEMIST and D > 128:
+        pytest.skip(reason="fused_recurrent_kda is not supported on alchemist for D>128")
     q = torch.rand(B, T, H, D, dtype=dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

    torch.manual_seed(42)
    if IS_INTEL_ALCHEMIST and D > 128:
        pytest.skip(reason="fused_recurrent_kda is not supported on alchemist for D>128")
    q = torch.rand(B, T, H, D, dtype=dtype)
    k = torch.rand(B, T, H, D, dtype=dtype)
    v = torch.rand(B, T, H, D, dtype=dtype)
    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=torch.float)) / gate_logit_normalizer
    beta = torch.randn(B, T, H, dtype=dtype).sigmoid()
    h0_kv = torch.randn(B, H, D, D, dtype=torch.float32)
    h0_vk = h0_kv.transpose(-1, -2).contiguous()
    q, k, v, g, beta, h0_kv, h0_vk = map(lambda x: x.to(device), (q, k, v, g, beta, h0_kv, h0_vk))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_kda.py` around lines 163 - 171, This transpose-only test is
missing the Intel Alchemist skip that test_fused_recurrent uses; add the same
guard so the test calls pytest.skip when running on Alchemist with D > 128.
Locate the transpose test in tests/ops/test_kda.py (same scope as the
q,k,v,g,beta,h0_kv,h0_vk setup), import or reuse the existing is_alchemist()
helper and add a conditional like `if is_alchemist() and D > 128:
pytest.skip(...)` before creating tensors so the test mirrors
test_fused_recurrent's behavior.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
fla/ops/common/chunk_delta_h.py (1)

727-748: ⚠️ Potential issue | 🟠 Major

Use chunk_size for BT here.

Line 734 still hardcodes BT = 64. If the forward pass used any other chunk_size, this backward wrapper will derive a different NT and state layout than the tensors it is consuming.

🐛 Proposed fix
-    BT = 64
+    BT = chunk_size

If 64 is the only supported value, assert that explicitly and drop the parameter instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/common/chunk_delta_h.py` around lines 727 - 748, The code currently
hardcodes BT = 64 which can mismatch the provided chunk_size and lead to
incorrect NT and state layouts; replace the hardcoded BT with BT = chunk_size
(or if only 64 is supported, assert chunk_size == 64 and remove the parameter)
so NT and chunk_offsets are computed consistently; update any dependent logic
that uses BT (e.g., prepare_chunk_offsets(cu_seqlens, BT), NT computation, dh
shape creation using NT, and dh0) to rely on the corrected BT value and keep the
existing transpose_state_layout, dh, dh0, cu_seqlens, chunk_indices handling
unchanged.
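The mismatch is plain arithmetic; a short sketch (num_chunks is a hypothetical stand-in for the wrapper's triton.cdiv(T, BT)):

```python
import math

def num_chunks(T, chunk_size):
    # NT as the wrapper derives it: number of chunk_size-sized chunks covering T tokens.
    return math.ceil(T / chunk_size)

T = 96
# A forward pass run with chunk_size=32 produces 3 chunks of state...
assert num_chunks(T, 32) == 3
# ...but a backward wrapper that hardcodes BT = 64 derives only 2,
# so it walks a state tensor laid out for a different NT.
assert num_chunks(T, 64) == 2
```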
fla/ops/kda/fused_recurrent.py (1)

144-167: ⚠️ Potential issue | 🔴 Critical

Mask g loads on padded K lanes.

BK is rounded up with triton.next_power_of_2(K), but p_g is loaded without mask_k at line 148. For non-power-of-two head sizes, this reads past the end of each g row, and those unmasked padded lanes flow directly into exp(b_gk) and corrupt the hidden state b_h via the multiplication at lines 157–159.

Proposed fix
-        b_g = tl.load(p_g, eviction_policy='evict_last').to(tl.float32)
+        b_g = tl.load(p_g, mask=mask_k, other=0, eviction_policy='evict_last').to(tl.float32)
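The failure mode is easy to reproduce outside Triton. A NumPy sketch under simplifying assumptions (a fixed garbage constant stands in for whatever memory happens to follow the g row; next_power_of_2 mirrors triton.next_power_of_2):

```python
import numpy as np

def next_power_of_2(n):
    return 1 << (n - 1).bit_length()

K = 96                    # non-power-of-two head size
BK = next_power_of_2(K)   # lane count the kernel actually runs with
assert BK == 128

g_row = np.linspace(-1.0, 0.0, K, dtype=np.float32)
garbage = np.full(BK - K, 50.0, dtype=np.float32)   # stand-in for out-of-bounds reads
unmasked = np.concatenate([g_row, garbage])
masked = np.concatenate([g_row, np.zeros(BK - K, dtype=np.float32)])  # other=0

# With other=0, padded lanes scale the state by exp(0) == 1: a harmless no-op.
assert np.all(np.exp(masked)[K:] == 1.0)
# Without a mask, exp(garbage) explodes and corrupts the padded state lanes.
assert np.exp(unmasked)[K:].max() > 1e20
```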
♻️ Duplicate comments (4)
tests/ops/test_gated_delta.py (1)

190-192: ⚠️ Potential issue | 🟡 Minor

Use non-zero initial state to properly test transpose layout.

h0_kv is initialized with torch.zeros, which masks layout bugs since both correct and incorrect transposed loads produce the same result (zeros). Use torch.randn like test_fused_recurrent_transpose_state does at line 266.

Suggested fix
-    h0_kv = torch.zeros(B, H, D, D, dtype=torch.float32)
+    h0_kv = torch.randn(B, H, D, D, dtype=torch.float32)
     h0_vk = h0_kv.transpose(-1, -2).contiguous()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_gated_delta.py` around lines 190 - 192, The initial state
h0_kv is incorrectly initialized with zeros which hides transpose/layout bugs;
replace its initialization with a non-zero random tensor (use torch.randn with
same shape and dtype) so h0_kv = torch.randn(B, H, D, D, dtype=torch.float32),
then compute h0_vk = h0_kv.transpose(-1, -2).contiguous() and keep the existing
map(... .to(device).requires_grad_(True)) call for (q, k, v, beta, g, h0_kv,
h0_vk) to ensure gradients and device placement remain the same.
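Why zeros hide the bug can be shown in a few lines of NumPy (one_step is a toy, ungated stand-in for the real recurrence):

```python
import numpy as np

# Fixed, asymmetric initial state: h0 != h0.T
h0 = np.arange(9, dtype=np.float64).reshape(3, 3)
k = np.array([0.5, -1.0, 2.0])
v = np.array([1.0, 0.0, -0.5])
q = np.array([1.0, 0.0, 0.0])

def one_step(h):
    # One ungated delta-rule-style step: accumulate k v^T, then read out with q.
    return q @ (h + np.outer(k, v))

# A nonzero state catches a layout mix-up: feeding the transpose changes the output.
assert not np.allclose(one_step(h0), one_step(h0.T))
# A zero state hides it completely: zeros(3, 3) is its own transpose.
z = np.zeros((3, 3))
assert np.allclose(one_step(z), one_step(z.T))
```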
tests/ops/test_kda.py (1)

163-171: ⚠️ Potential issue | 🟡 Minor

Add Intel Alchemist skip guard for consistency.

test_fused_recurrent at line 102-103 skips when IS_INTEL_ALCHEMIST and D > 128, but this transpose test is missing that guard. This could cause failures on Alchemist hardware instead of exercising the transpose layout path.

Suggested fix
 def test_fused_recurrent_transpose_state(
     B: int,
     T: int,
     H: int,
     D: int,
     scale: float,
     gate_logit_normalizer: float,
     dtype: torch.dtype,
 ):
     torch.manual_seed(42)
+    if IS_INTEL_ALCHEMIST and D > 128:
+        pytest.skip(reason="fused_recurrent_kda is not supported on alchemist for D>128")
     q = torch.rand(B, T, H, D, dtype=dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_kda.py` around lines 163 - 171, Add the same Intel Alchemist
skip guard used in test_fused_recurrent to this transpose test: check
IS_INTEL_ALCHEMIST and D > 128 at the start of the test (before creating
q,k,v,g,beta,h0_kv/h0_vk) and call pytest.skip(...) when true so the
transpose-layout path is not executed on Alchemist hardware; reference the
symbols IS_INTEL_ALCHEMIST and D and mirror the skip condition/behavior from
test_fused_recurrent.
fla/ops/gated_delta_rule/fused_recurrent.py (1)

171-187: ⚠️ Potential issue | 🟠 Major

At least validate initial_state against the selected state shape.

transpose_state_layout now changes the state contract here, but this wrapper still accepts any initial_state. The public entry point only validates the sequence count, so incompatible HV or tail dims can still reach the kernel's hard-coded pointer math and be silently misread.

🛡️ Proposed fix
     B, T, H, K, V = *k.shape, v.shape[-1]
     HV = v.shape[2]
     N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    if initial_state is not None:
+        expected_state_shape = (N, HV, V, K) if transpose_state_layout else (N, HV, K, V)
+        if tuple(initial_state.shape) != expected_state_shape:
+            raise ValueError(
+                f"`initial_state` must have shape {expected_state_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`; got {tuple(initial_state.shape)}."
+            )
     BK = triton.next_power_of_2(K)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 171 - 187, The
wrapper must validate any provided initial_state against the expected layout
determined by transpose_state_layout before using it: in the function
(fused_recurrent / where transpose_state_layout, HV, K, V, N are computed) check
that when initial_state is not None its shape equals (N, HV, V, K) if
transpose_state_layout is True, otherwise equals (N, HV, K, V), and raise a
clear ValueError if mismatched (include expected vs actual shape); also validate
its dtype/device are compatible with q/v and that the leading dimension N
matches len(cu_seqlens)-1 when cu_seqlens is provided. Ensure this validation
occurs before any pointer/stride math or kernel dispatch that assumes the state
layout.
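A minimal sketch of such a guard, assuming the shapes named in this comment (check_initial_state_shape is illustrative, not the wrapper's actual code):

```python
def check_initial_state_shape(shape, N, HV, K, V, transpose_state_layout):
    # The expected tail flips with the layout flag.
    expected = (N, HV, V, K) if transpose_state_layout else (N, HV, K, V)
    if tuple(shape) != expected:
        raise ValueError(
            f"`initial_state` must have shape {expected} when "
            f"`transpose_state_layout={transpose_state_layout}`; got {tuple(shape)}."
        )

# A matching [N, HV, K, V] state passes silently...
check_initial_state_shape((2, 8, 64, 128), N=2, HV=8, K=64, V=128, transpose_state_layout=False)
# ...while the same tensor under transpose_state_layout=True fails fast.
try:
    check_initial_state_shape((2, 8, 64, 128), N=2, HV=8, K=64, V=128, transpose_state_layout=True)
    raise AssertionError("expected a ValueError")
except ValueError as e:
    assert "transpose_state_layout=True" in str(e)
```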
fla/ops/kda/fused_recurrent.py (1)

239-262: ⚠️ Potential issue | 🟠 Major

At least validate initial_state against the selected state shape.

This wrapper now flips the state contract between [N, HV, K, V] and [N, HV, V, K], but it still forwards any initial_state straight into the kernel and, in the inplace path, reuses it as final_state. The varlen entry point only checks shape[0], so incompatible HV or tail dims can still be silently misread.

🛡️ Proposed fix
     B, T, H, K, V = *k.shape, v.shape[-1]
     HV = v.shape[2]
     N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    if initial_state is not None:
+        expected_state_shape = (N, HV, V, K) if transpose_state_layout else (N, HV, K, V)
+        if tuple(initial_state.shape) != expected_state_shape:
+            raise ValueError(
+                f"`initial_state` must have shape {expected_state_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`; got {tuple(initial_state.shape)}."
+            )
     BK = triton.next_power_of_2(K)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/fused_recurrent.py` around lines 239 - 262, The wrapper flips
state layout between [N, HV, K, V] and [N, HV, V, K] based on
transpose_state_layout but never validates initial_state against the selected
layout; add explicit shape validation where initial_state is accepted or reused
(symbols: initial_state, final_state, transpose_state_layout,
inplace_final_state, output_final_state) — check that initial_state.ndim and
each dimension (N, HV, K, V or V,K order) match the expected shape and dtype,
and raise a clear ValueError/Assertion if they don’t; also when
inplace_final_state is true, ensure final_state (== initial_state) exactly
matches the expected layout before using it, and perform the same validation for
any varlen entry-point that previously only checked shape[0] so HV and tail
dimensions cannot be silently misread.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 114-126: The loads for p_gk and p_gv are not masked, so padded
lanes read garbage and corrupt b_h via exp(...) multiplications; update the
tl.load calls that produce b_gk and b_gv to use the same padding mask used for
q/k/v/beta loads (i.e., pass the mask and mask_fill value) so out-of-range lanes
are zeroed before computing exp and multiplying into b_h, keeping the existing
TRANSPOSE_STATE branching and use of USE_GK/USE_GV and symbols p_gk, p_gv, b_gk,
b_gv, b_h unchanged.

---

Outside diff comments:
In `@fla/ops/common/chunk_delta_h.py`:
- Around line 727-748: The code currently hardcodes BT = 64 which can mismatch
the provided chunk_size and lead to incorrect NT and state layouts; replace the
hardcoded BT with BT = chunk_size (or if only 64 is supported, assert chunk_size
== 64 and remove the parameter) so NT and chunk_offsets are computed
consistently; update any dependent logic that uses BT (e.g.,
prepare_chunk_offsets(cu_seqlens, BT), NT computation, dh shape creation using
NT, and dh0) to rely on the corrected BT value and keep the existing
transpose_state_layout, dh, dh0, cu_seqlens, chunk_indices handling unchanged.

---

Duplicate comments:
In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 171-187: The wrapper must validate any provided initial_state
against the expected layout determined by transpose_state_layout before using
it: in the function (fused_recurrent / where transpose_state_layout, HV, K, V, N
are computed) check that when initial_state is not None its shape equals (N, HV,
V, K) if transpose_state_layout is True, otherwise equals (N, HV, K, V), and
raise a clear ValueError if mismatched (include expected vs actual shape); also
validate its dtype/device are compatible with q/v and that the leading dimension
N matches len(cu_seqlens)-1 when cu_seqlens is provided. Ensure this validation
occurs before any pointer/stride math or kernel dispatch that assumes the state
layout.

In `@fla/ops/kda/fused_recurrent.py`:
- Around line 239-262: The wrapper flips state layout between [N, HV, K, V] and
[N, HV, V, K] based on transpose_state_layout but never validates initial_state
against the selected layout; add explicit shape validation where initial_state
is accepted or reused (symbols: initial_state, final_state,
transpose_state_layout, inplace_final_state, output_final_state) — check that
initial_state.ndim and each dimension (N, HV, K, V or V,K order) match the
expected shape and dtype, and raise a clear ValueError/Assertion if they don’t;
also when inplace_final_state is true, ensure final_state (== initial_state)
exactly matches the expected layout before using it, and perform the same
validation for any varlen entry-point that previously only checked shape[0] so
HV and tail dimensions cannot be silently misread.

In `@tests/ops/test_gated_delta.py`:
- Around line 190-192: The initial state h0_kv is incorrectly initialized with
zeros which hides transpose/layout bugs; replace its initialization with a
non-zero random tensor (use torch.randn with same shape and dtype) so h0_kv =
torch.randn(B, H, D, D, dtype=torch.float32), then compute h0_vk =
h0_kv.transpose(-1, -2).contiguous() and keep the existing map(...
.to(device).requires_grad_(True)) call for (q, k, v, beta, g, h0_kv, h0_vk) to
ensure gradients and device placement remain the same.

In `@tests/ops/test_kda.py`:
- Around line 163-171: Add the same Intel Alchemist skip guard used in
test_fused_recurrent to this transpose test: check IS_INTEL_ALCHEMIST and D >
128 at the start of the test (before creating q,k,v,g,beta,h0_kv/h0_vk) and call
pytest.skip(...) when true so the transpose-layout path is not executed on
Alchemist hardware; reference the symbols IS_INTEL_ALCHEMIST and D and mirror
the skip condition/behavior from test_fused_recurrent.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 264350ca-243a-4ce2-8571-444813ef9859

📥 Commits

Reviewing files that changed from the base of the PR and between b574148 and d4b4909.

📒 Files selected for processing (16)
  • fla/ops/common/backends/intracard.py
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/common/chunk_o.py
  • fla/ops/common/intracard_cp.py
  • fla/ops/cp/chunk_delta_h.py
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/fused_recurrent.py
  • fla/ops/gla/chunk.py
  • fla/ops/kda/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_fwd.py
  • fla/ops/kda/fused_recurrent.py
  • tests/context_parallel/test_cp_gdn.py
  • tests/context_parallel/test_cp_kda.py
  • tests/ops/test_gated_delta.py
  • tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/context_parallel/test_cp_kda.py
  • fla/ops/kda/chunk.py

Comment on lines 114 to +126
        if USE_GK:
            b_gk = tl.load(p_gk).to(tl.float32)
-            b_h *= exp(b_gk[:, None])
+            if TRANSPOSE_STATE:
+                b_h *= exp(b_gk[None, :])
+            else:
+                b_h *= exp(b_gk[:, None])

        if USE_GV:
            b_gv = tl.load(p_gv).to(tl.float32)
-            b_h *= exp(b_gv[None, :])

        b_v = b_beta * (b_v - tl.sum(b_h * b_k[:, None], 0))
        b_h += b_k[:, None] * b_v

        # [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
+        if TRANSPOSE_STATE:
+            b_h *= exp(b_gv[:, None])
+        else:
+            b_h *= exp(b_gv[None, :])

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Scripts executed against fla-org/flash-linear-attention (outputs collapsed):

head -c 1000 "fla/ops/gated_delta_rule/fused_recurrent.py" && echo -e "\n...[file exists]"
wc -l fla/ops/gated_delta_rule/fused_recurrent.py
sed -n '100,140p' fla/ops/gated_delta_rule/fused_recurrent.py
sed -n '170,195p' fla/ops/gated_delta_rule/fused_recurrent.py

# Search for mask_k and mask_v in the kernel to see if they are defined
rg -n "mask_k|mask_v" fla/ops/gated_delta_rule/fused_recurrent.py

# Check the broader kernel signature and how other loads are masked
sed -n '1,100p' fla/ops/gated_delta_rule/fused_recurrent.py | grep -n "def \|@triton\|mask\|load"

# Find where p_gk and p_gv pointers are initialized in the kernel
sed -n '50,90p' fla/ops/gated_delta_rule/fused_recurrent.py

# Check how gk and gv dimensions are used and whether they are padded or not
sed -n '60,85p' fla/ops/gated_delta_rule/fused_recurrent.py

# Look for pointer arithmetic with gk and gv to understand their shape
rg -n "p_gk|p_gv|gk\[|gv\[" fla/ops/gated_delta_rule/fused_recurrent.py | head -20

# Search for how g, gk, gv dimensions are computed and passed to the kernel
sed -n '200,250p' fla/ops/gated_delta_rule/fused_recurrent.py

# Look at the full kernel function signature to see gk, gv parameter types
sed -n '20,75p' fla/ops/gated_delta_rule/fused_recurrent.py

# Check if there are any input shape validations that enforce K/V to be powers of two
sed -n '170,200p' fla/ops/gated_delta_rule/fused_recurrent.py | grep -A 10 "def fused_recurrent_gated_delta_rule_fwd"

# Look for any assertions or constraints on K/V dimensions
rg -n "assert|constraint|power_of_2" fla/ops/gated_delta_rule/fused_recurrent.py

# Check the actual load patterns - confirm they use mask or don't use mask
sed -n '95,130p' fla/ops/gated_delta_rule/fused_recurrent.py

# Verify the pointer arithmetic for all loads to ensure consistency
sed -n '65,75p' fla/ops/gated_delta_rule/fused_recurrent.py

# Look at how b_g is loaded to confirm it's a scalar (no K dimension)
sed -n '110,120p' fla/ops/gated_delta_rule/fused_recurrent.py

# Double-check the full context of pointer initialization for all gate tensors
sed -n '65,85p' fla/ops/gated_delta_rule/fused_recurrent.py

Mask gk and gv loads on padded lanes.

When K or V are not powers of two, BK and BV are rounded up, but p_gk and p_gv are loaded without masking. This reads garbage values into the padded lanes, which then corrupt the state computation through the exp(...) multiplications. All other dimension-dependent loads (q, k, v, beta) correctly use masks; gk and gv must too.

🐛 Proposed fix
-            b_gk = tl.load(p_gk).to(tl.float32)
+            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
-            b_gv = tl.load(p_gv).to(tl.float32)
+            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (masked loads applied, keeping the TRANSPOSE_STATE branching):

        if USE_GK:
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            if TRANSPOSE_STATE:
                b_h *= exp(b_gk[None, :])
            else:
                b_h *= exp(b_gk[:, None])

        if USE_GV:
            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
            if TRANSPOSE_STATE:
                b_h *= exp(b_gv[:, None])
            else:
                b_h *= exp(b_gv[None, :])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 114 - 126, The
loads for p_gk and p_gv are not masked, so padded lanes read garbage and corrupt
b_h via exp(...) multiplications; update the tl.load calls that produce b_gk and
b_gv to use the same padding mask used for q/k/v/beta loads (i.e., pass the mask
and mask_fill value) so out-of-range lanes are zeroed before computing exp and
multiplying into b_h, keeping the existing TRANSPOSE_STATE branching and use of
USE_GK/USE_GV and symbols p_gk, p_gv, b_gk, b_gv, b_h unchanged.

Add transpose_state_layout parameter to chunk, fused_recurrent, and
context parallel paths for both KDA and GDN. When enabled, all state
tensors use [V,K] layout instead of [K,V] to improve memory access
patterns.
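The layout contract the commit message describes can be sketched with a stripped-down (ungated, beta-free) accumulation; the two layouts hold the same state, only the memory order differs:

```python
import numpy as np

rng = np.random.default_rng(42)
K, V, T = 4, 6, 5
ks = rng.standard_normal((T, K))
vs = rng.standard_normal((T, V))

# [K, V] layout: the state accumulates k v^T
h_kv = np.zeros((K, V))
# [V, K] layout: the same state stored transposed, updated with v k^T
h_vk = np.zeros((V, K))
for t in range(T):
    h_kv += np.outer(ks[t], vs[t])
    h_vk += np.outer(vs[t], ks[t])

# Identical information, different memory order.
assert np.allclose(h_vk, h_kv.T)
```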
@coderabbitai coderabbitai bot left a comment

♻️ Duplicate comments (5)
fla/ops/gated_delta_rule/chunk.py (1)

313-313: ⚠️ Potential issue | 🟠 Major

Validate initial_state against transpose_state_layout.

This flag changes the expected cache tail from [H, K, V] to [H, V, K], but the wrapper still only checks the batch dimension. Passing an old [N, H, K, V] state with transpose_state_layout=True will be silently misread by the Triton kernels.

🩹 Suggested fix
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                 f"Please flatten variable-length inputs before processing.",
             )
         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
             raise ValueError(
                 f"The number of initial states is expected to be equal to the number of input sequences, "
                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
             )
+    if initial_state is not None:
+        expected_tail = (
+            (q.shape[2], v.shape[-1], k.shape[-1])
+            if transpose_state_layout
+            else (q.shape[2], k.shape[-1], v.shape[-1])
+        )
+        if initial_state.ndim != 4 or tuple(initial_state.shape[1:]) != expected_tail:
+            raise ValueError(
+                f"`initial_state` must have shape [N, {expected_tail[0]}, {expected_tail[1]}, {expected_tail[2]}] "
+                f"when transpose_state_layout={transpose_state_layout}, got {tuple(initial_state.shape)}."
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/chunk.py` at line 313, The wrapper must validate
that the incoming initial_state shape matches transpose_state_layout: when
transpose_state_layout is False expect cache tail layout [N, H, K, V] and when
True expect [N, H, V, K]; update the code that handles the initial_state
(referencing the transpose_state_layout flag and the initial_state
variable/parameter in the gated-delta/chunk wrapper) to check the dimensionality
and the order of the last two dims and raise a clear error if they mismatch
(include expected vs actual shapes in the message) so an old [N, H, K, V] passed
with transpose_state_layout=True is detected rather than silently misread by the
Triton kernels.
fla/ops/gated_delta_rule/fused_recurrent.py (2)

171-185: ⚠️ Potential issue | 🟠 Major

Reject mismatched initial_state layouts in transpose mode.

This path now allocates [N, HV, V, K] final states when transpose_state_layout=True, but it still accepts any initial_state tail shape. A stale [N, HV, K, V] cache will be silently misread by the kernel.

🩹 Suggested fix
     B, T, H, K, V = *k.shape, v.shape[-1]
     HV = v.shape[2]
     N = B if cu_seqlens is None else len(cu_seqlens) - 1
     BK = triton.next_power_of_2(K)
     BV = min(8, triton.next_power_of_2(V)) if gv is None else triton.next_power_of_2(V)
     NV = triton.cdiv(V, BV)
+
+    if initial_state is not None:
+        expected_tail = (HV, V, K) if transpose_state_layout else (HV, K, V)
+        if initial_state.ndim != 4 or tuple(initial_state.shape[1:]) != expected_tail:
+            raise ValueError(
+                f"`initial_state` must have shape [N, {expected_tail[0]}, {expected_tail[1]}, {expected_tail[2]}] "
+                f"when transpose_state_layout={transpose_state_layout}, got {tuple(initial_state.shape)}."
+            )
 
     o = torch.empty_like(v)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 171 - 185, When
transpose_state_layout=True, validate that any provided initial_state has the
transposed tail layout [N, HV, V, K] (not [N, HV, K, V]) and raise a clear
ValueError if it does not; locate the logic around transpose_state_layout,
initial_state, and final_state in fused_recurrent.py (the block that allocates
final_state when output_final_state is true) and add a shape/layout check before
using or allocating final_state so a stale [N, HV, K, V] cache is rejected
rather than silently misread by the kernel.

114-126: ⚠️ Potential issue | 🔴 Critical

Mask gk/gv loads on padded lanes.

BK and BV round K/V up, but p_gk and p_gv are still loaded without masks. On non-power-of-two heads, the padded lanes feed garbage into the exp(...) multipliers and corrupt the state update.

🩹 Suggested fix
         if USE_GK:
-            b_gk = tl.load(p_gk).to(tl.float32)
+            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
             if TRANSPOSE_STATE:
                 b_h *= exp(b_gk[None, :])
             else:
                 b_h *= exp(b_gk[:, None])
 
         if USE_GV:
-            b_gv = tl.load(p_gv).to(tl.float32)
+            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
             if TRANSPOSE_STATE:
                 b_h *= exp(b_gv[:, None])
             else:
                 b_h *= exp(b_gv[None, :])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 114 - 126, The
loads for p_gk/p_gv should be masked to avoid reading padded lanes (which BK/BV
round up) and injecting garbage into exp multipliers; modify the p_gk and p_gv
loads so that tl.load (or equivalent) is passed a mask that zeros out indices
beyond the true head size (or load full then set masked entries to 0), and
ensure the mask shape matches the TRANSPOSE_STATE branching (i.e., shape the
mask as [None, :] vs [:, None] to match how b_gk/b_gv are broadcast into b_h);
apply this for both USE_GK (b_gk from p_gk) and USE_GV (b_gv from p_gv) before
calling exp and multiplying into b_h.
fla/ops/kda/fused_recurrent.py (1)

239-262: ⚠️ Potential issue | 🟠 Major

Fail fast on recurrent-state layout mismatches.

When transpose_state_layout=True, this path switches the kernel and allocated state to [N, HV, V, K], but a provided initial_state is still accepted unchecked. Reusing an older [N, HV, K, V] cache will be silently reinterpreted.

🩹 Suggested fix
     B, T, H, K, V = *k.shape, v.shape[-1]
     HV = v.shape[2]
     N = B if cu_seqlens is None else len(cu_seqlens) - 1
     BK = triton.next_power_of_2(K)
     BV = 32
+
+    if initial_state is not None:
+        expected_tail = (HV, V, K) if transpose_state_layout else (HV, K, V)
+        if initial_state.ndim != 4 or tuple(initial_state.shape[1:]) != expected_tail:
+            raise ValueError(
+                f"`initial_state` must have shape [N, {expected_tail[0]}, {expected_tail[1]}, {expected_tail[2]}] "
+                f"when transpose_state_layout={transpose_state_layout}, got {tuple(initial_state.shape)}."
+            )
 
     if out is None:
         out = torch.zeros_like(v)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fla/ops/kda/fused_recurrent.py` around lines 239 - 262, The code allows an
initial_state whose memory layout doesn't match transpose_state_layout, causing
silent reinterpretation; add a fast-fail shape check before using initial_state:
compute N, HV, K, V as in the function and if initial_state is not None assert
(or raise ValueError) that initial_state.shape equals (N, HV, V, K) when
transpose_state_layout is True and equals (N, HV, K, V) when
transpose_state_layout is False, with a clear error message referencing
transpose_state_layout, initial_state and expected shape; apply this check
around the branch that sets final_state (the logic using transpose_state_layout,
inplace_final_state, output_final_state and final_state) so mismatched caches
are rejected immediately.
tests/ops/test_kda.py (1)

163-171: ⚠️ Potential issue | 🟡 Minor

Add Intel Alchemist skip guard for consistency.

This test is missing the IS_INTEL_ALCHEMIST guard that test_fused_recurrent uses. Without it, the test will fail on Alchemist GPUs with D > 128 instead of being skipped.

Suggested fix
 def test_fused_recurrent_transpose_state(
     B: int,
     T: int,
     H: int,
     D: int,
     scale: float,
     gate_logit_normalizer: float,
     dtype: torch.dtype,
 ):
     torch.manual_seed(42)
+    if IS_INTEL_ALCHEMIST and D > 128:
+        pytest.skip(reason="fused_recurrent_kda is not supported on alchemist for D>128")
     q = torch.rand(B, T, H, D, dtype=dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/ops/test_kda.py` around lines 163 - 171, This test block lacks the
IS_INTEL_ALCHEMIST guard that causes tests to skip on Alchemist GPUs when D >
128; add the same check used in test_fused_recurrent: before creating tensors
(before torch.manual_seed(42)), check if IS_INTEL_ALCHEMIST and D > 128 and call
pytest.skip with a short message, and ensure IS_INTEL_ALCHEMIST (and pytest if
not already imported) is available in the test module so the guard can be
applied exactly where q, k, v, g, beta, h0_kv, h0_vk are constructed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Line 313: The wrapper must validate that the incoming initial_state shape
matches transpose_state_layout: when transpose_state_layout is False expect
cache tail layout [N, H, K, V] and when True expect [N, H, V, K]; update the
code that handles the initial_state (referencing the transpose_state_layout flag
and the initial_state variable/parameter in the gated-delta/chunk wrapper) to
check the dimensionality and the order of the last two dims and raise a clear
error if they mismatch (include expected vs actual shapes in the message) so an
old [N, H, K, V] passed with transpose_state_layout=True is detected rather than
silently misread by the Triton kernels.

In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 171-185: When transpose_state_layout=True, validate that any
provided initial_state has the transposed tail layout [N, HV, V, K] (not [N, HV,
K, V]) and raise a clear ValueError if it does not; locate the logic around
transpose_state_layout, initial_state, and final_state in fused_recurrent.py
(the block that allocates final_state when output_final_state is true) and add a
shape/layout check before using or allocating final_state so a stale [N, HV, K,
V] cache is rejected rather than silently misread by the kernel.
- Around line 114-126: The loads for p_gk/p_gv should be masked to avoid reading
padded lanes (which BK/BV round up) and injecting garbage into exp multipliers;
modify the p_gk and p_gv loads so that tl.load (or equivalent) is passed a mask
that zeros out indices beyond the true head size (or load full then set masked
entries to 0), and ensure the mask shape matches the TRANSPOSE_STATE branching
(i.e., shape the mask as [None, :] vs [:, None] to match how b_gk/b_gv are
broadcast into b_h); apply this for both USE_GK (b_gk from p_gk) and USE_GV
(b_gv from p_gv) before calling exp and multiplying into b_h.
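The masking fix above can be illustrated without Triton. A plain-Python sketch of loading a gate vector whose block size `BK` is rounded up past the true head size `K`, with out-of-range lanes forced to a neutral value so they contribute nothing to the exp multipliers (names are illustrative, not the kernel's):

```python
def masked_load(buf, K, BK, other=0.0):
    """Read BK lanes from `buf`, keeping only lanes < K as real data.

    Padded lanes (K <= i < BK) are replaced by `other` instead of whatever
    garbage the rounded-up block would otherwise read, mirroring
    tl.load(p, mask=..., other=0.0) in the Triton kernel.
    """
    return [buf[i] if i < K else other for i in range(BK)]

# Head size K=3 rounded up to a block of BK=4: the 4th lane holds garbage.
buf = [0.1, 0.2, 0.3, 999.0]
b_gk = masked_load(buf, K=3, BK=4)  # → [0.1, 0.2, 0.3, 0.0]
```

For the TRANSPOSE_STATE branching, the same mask would then be broadcast as `[None, :]` or `[:, None]` to match how `b_gk`/`b_gv` are applied to `b_h`.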



📥 Commits

Reviewing files that changed from the base of the PR and between d4b4909 and 283a560.

📒 Files selected for processing (16)
  • fla/ops/common/backends/intracard.py
  • fla/ops/common/chunk_delta_h.py
  • fla/ops/common/chunk_o.py
  • fla/ops/common/intracard_cp.py
  • fla/ops/cp/chunk_delta_h.py
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/gated_delta_rule/fused_recurrent.py
  • fla/ops/gla/chunk.py
  • fla/ops/kda/chunk.py
  • fla/ops/kda/chunk_bwd.py
  • fla/ops/kda/chunk_fwd.py
  • fla/ops/kda/fused_recurrent.py
  • tests/context_parallel/test_cp_gdn.py
  • tests/context_parallel/test_cp_kda.py
  • tests/ops/test_gated_delta.py
  • tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • fla/ops/kda/chunk_bwd.py
  • tests/context_parallel/test_cp_kda.py

@zhiyuan1i zhiyuan1i merged commit 5dfae2f into main Mar 10, 2026
4 checks passed
@zhiyuan1i zhiyuan1i deleted the lzy/transpose-kv branch March 10, 2026 13:31