[KDA][GDN] Support transpose_state_layout for [V,K] state memory layout #776
Conversation
Walkthrough

Adds an optional `transpose_state_layout: bool` flag threaded through Python APIs, Intracard/CP backends, Triton kernels (`TRANSPOSE_STATE` constexpr / autotune keys), and tests to enable an alternate transposed internal state layout; default behavior is unchanged when the flag is False.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as Operator call
    participant PyAPI as Python API / Wrapper
    participant Backend as Intracard/CP Backend
    participant Kernel as Triton Kernel
    participant Device as GPU Memory
    User->>PyAPI: call op(..., transpose_state_layout=flag)
    PyAPI->>Backend: forward flag, prepare/reshape tensors (K,V) or (V,K)
    Backend->>Kernel: launch kernel with TRANSPOSE_STATE=flag
    Kernel->>Device: read/write state buffers (layout depends on flag)
    Kernel-->>Backend: return outputs + final_state (shaped per flag)
    Backend-->>PyAPI: return outputs (final_state in chosen layout)
    PyAPI-->>User: return results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Summary of Changes

This pull request allows state tensors in KDA and GDN operations to adopt a transposed memory layout. The change aims to improve memory access patterns, which can be crucial for performance-sensitive applications.
Code Review
This pull request introduces support for a transposed state memory layout ([V,K]) across various components for KDA and GDN to improve memory access patterns. The changes are extensive, consistently propagating the transpose_state_layout flag from high-level functions down to the Triton kernels. The logic for handling both memory layouts appears correct, and the PR includes new tests to verify this functionality.
My main feedback concerns the significant code duplication introduced in several Triton kernels. Many if/else blocks for the new TRANSPOSE_STATE flag repeat large chunks of code with only minor differences. While this might be necessary to some extent due to Triton's constraints, there are opportunities to refactor and reduce this duplication, which would greatly improve the code's readability and maintainability. I've added specific comments with suggestions on how to approach this refactoring.
Note: Security Review did not run due to the size of the PR.
```python
if TRANSPOSE_STATE:
    b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
else:
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)
```
There's significant code duplication here for initializing the state tensors b_h1, b_h2, etc., based on TRANSPOSE_STATE. This can be refactored to improve readability and maintainability by defining the shape conditionally. Since TRANSPOSE_STATE is a constexpr, the compiler can optimize this, but the duplicated code makes it harder for developers to maintain.
Consider defining the shape dimensions conditionally and reusing them:

```python
shape_dim0 = BV if TRANSPOSE_STATE else 64
shape_dim1 = 64 if TRANSPOSE_STATE else BV
b_h1 = tl.zeros([shape_dim0, shape_dim1], dtype=tl.float32)
if K > 64:
    b_h2 = tl.zeros([shape_dim0, shape_dim1], dtype=tl.float32)
# ... and so on
```

This pattern of duplication appears throughout the file and could be similarly refactored.
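The reviewer's suggestion can be illustrated outside of Triton. The sketch below is plain Python (a hypothetical `init_state_blocks` helper, not kernel code): pick the accumulator shape once from the layout flag, then reuse it for every `b_h` block instead of duplicating the whole if/else.

```python
# Plain-Python sketch of the suggested refactor: choose the state-block shape
# once from the layout flag, then allocate each accumulator with it. In the
# kernel the same idea works because TRANSPOSE_STATE is a tl.constexpr, so the
# conditional shape is resolved at compile time.

def init_state_blocks(K, BV, transpose_state):
    """Return zero-filled accumulators shaped per the layout flag:
    [64, BV] blocks for the [K, V] layout, [BV, 64] for [V, K]."""
    shape = (BV, 64) if transpose_state else (64, BV)
    zeros = lambda s: [[0.0] * s[1] for _ in range(s[0])]
    blocks = [zeros(shape)]            # b_h1 always exists
    for bound in (64, 128, 192):       # b_h2..b_h4 depend on K
        if K > bound:
            blocks.append(zeros(shape))
    return shape, blocks

shape, blocks = init_state_blocks(K=256, BV=32, transpose_state=False)
# K=256 exceeds 192, so all four accumulators exist in the [64, BV] layout.
```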
```diff
 if USE_INITIAL_STATE:
-    p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
+    if TRANSPOSE_STATE:
+        p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
+    else:
+        p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
     b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
     if K > 64:
-        p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
+        if TRANSPOSE_STATE:
+            p_h0_2 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0))
+        else:
+            p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
         b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
     if K > 128:
-        p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
+        if TRANSPOSE_STATE:
+            p_h0_3 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0))
+        else:
+            p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
         b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
     if K > 192:
-        p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
+        if TRANSPOSE_STATE:
+            p_h0_4 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0))
+        else:
+            p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
         b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
```
This block for loading the initial state has a lot of duplicated code due to the TRANSPOSE_STATE flag. This pattern of duplication appears in several other places in this kernel (e.g., main recurrence loop, final state storing). While tl.constexpr helps the compiler, it makes the code harder to read and maintain. A bug fix in one path might be missed in the other.
Consider refactoring to reduce this duplication. For example, you could define the parameters for tl.make_block_ptr conditionally at the beginning of the if USE_INITIAL_STATE: block, and then reuse them. This would make the logic for different K values much cleaner.
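One way to realize that refactor is to compute the `tl.make_block_ptr` arguments once per 64-wide K slice, driven by the layout flag. The helper below is a plain-Python sketch with illustrative names (`h0_block_params` is not part of the fla codebase); it mirrors the argument tuples in the diff above.

```python
# Hypothetical helper illustrating the suggestion: derive the block-pointer
# arguments (tensor shape, strides, offsets, block shape) from the layout
# flag and the K offset, so one loop over k_off in (0, 64, 128, 192) can
# replace the four duplicated if/else branches.

def h0_block_params(K, V, BV, i_v, k_off, transpose_state):
    """Return (shape, strides, offsets, block_shape) for one state block."""
    if transpose_state:                  # state stored as [V, K]
        return (V, K), (K, 1), (i_v * BV, k_off), (BV, 64)
    else:                                # default [K, V] layout
        return (K, V), (V, 1), (k_off, i_v * BV), (64, BV)

# One loop then covers b_h1..b_h4:
params = [h0_block_params(256, 128, 32, i_v=2, k_off=o, transpose_state=True)
          for o in (0, 64, 128, 192)]
```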
```python
if TRANSPOSE_STATE:
    b_dh1 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 64:
        b_dh2 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 128:
        b_dh3 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 192:
        b_dh4 = tl.zeros([BV, 64], dtype=tl.float32)
else:
    b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_dh2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_dh3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_dh4 = tl.zeros([64, BV], dtype=tl.float32)
```
Similar to the forward kernel, there's duplicated logic for initializing b_dh1, b_dh2, etc. This reduces readability and increases the chance of introducing bugs if one path is modified but the other is not. Please consider refactoring to consolidate the common logic and handle the differences based on TRANSPOSE_STATE more concisely.
```diff
 if HAS_H0:
     orig_seq_id = tl.load(h0_seq_ids + i_seq).to(tl.int32)
-    p_h0 = tl.make_block_ptr(
-        h0 + (orig_seq_id * H + i_h) * K * V,
-        (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0)
-    )
-    b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+    if TRANSPOSE_STATE:
+        p_h0 = tl.make_block_ptr(
+            h0 + (orig_seq_id * H + i_h) * V * K,
+            (V, K), (K, 1), (i_v * BV, 0), (BV, BK), (1, 0)
+        )
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
+    else:
+        p_h0 = tl.make_block_ptr(
+            h0 + (orig_seq_id * H + i_h) * K * V,
+            (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0)
+        )
+        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
 else:
-    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+    if TRANSPOSE_STATE:
+        b_h = tl.zeros([BV, BK], dtype=tl.float32)
+    else:
+        b_h = tl.zeros([BK, BV], dtype=tl.float32)
```
This block for initializing the state has duplicated code for handling TRANSPOSE_STATE. This pattern continues in the merge loop. To improve maintainability, consider refactoring to reduce this duplication. You could, for instance, set up shape and offset-related variables conditionally at the beginning of the if INTRACARD_MODE: block and reuse them.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
fla/ops/cp/chunk_delta_h.py (1)
Lines 953-992: ⚠️ Potential issue | 🔴 Critical

Offset `gk` before these backward-decay loads.

These `tl.load(gk + last_idx * H * K + ...)` expressions never incorporate `bos` or `i_h`, so any non-zero sequence offset or head index reads decays from the wrong slice. In CP KDA backward this corrupts the preprocessed `dht`/`initial_state` for every head except the first one.

🩹 Proposed fix

```diff
 q += ((bos * H + i_h) * K).to(tl.int64)
 k += ((bos * H + i_h) * K).to(tl.int64)
 w += ((bos * H + i_h) * K).to(tl.int64)
+if USE_GK:
+    gk += ((bos * H + i_h) * K).to(tl.int64)
 dhm += i_h * K * (V + K)
 stride_k = H * K
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/cp/chunk_delta_h.py` around lines 953 - 992, The backward loads for gk use tl.load(gk + last_idx * H * K + o_...) which ignores sequence offset (bos) and head index (i_h); update all tl.load calls that compute b_gk_last1/2/3/4 to base the address on both bos and i_h (e.g. replace last_idx * H * K with (bos + last_idx) * H * K + i_h * K) so each load becomes tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_kN, ...), ensuring the gk slice is correctly offset per sequence and head for b_gk_last1, b_gk_last2, b_gk_last3, b_gk_last4.
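The address arithmetic behind this finding can be checked with a small worked example. For a flattened `[T, H, K]` buffer, element `(bos + last_idx, i_h, k)` lives at `(bos + last_idx) * H * K + i_h * K + k`; the function names below are illustrative, not from the kernel.

```python
# Worked example of the addressing bug: dropping bos and i_h from the flat
# offset only lands on the right element when bos == 0 and i_h == 0.

H, K = 4, 8

def gk_offset_buggy(last_idx, k):
    """Offset as in the flagged loads: ignores bos and i_h."""
    return last_idx * H * K + k

def gk_offset_fixed(bos, last_idx, i_h, k):
    """Offset incorporating the sequence offset and head index."""
    return (bos + last_idx) * H * K + i_h * K + k

# Agreement only for the first sequence and first head:
assert gk_offset_fixed(0, 5, 0, 3) == gk_offset_buggy(5, 3)
# Any non-zero bos or i_h reads from the wrong slice:
assert gk_offset_fixed(16, 5, 2, 3) != gk_offset_buggy(5, 3)
```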
🧹 Nitpick comments (1)
tests/ops/test_gated_delta.py (1)
Lines 236-296: Consider adding gradient verification for the fused recurrent transpose test.

`test_fused_recurrent_transpose_state` only verifies forward outputs. The chunk version includes gradient checks. Consider adding backward-pass verification for completeness, especially since `fused_recurrent_gated_delta_rule` supports gradients.

💡 Suggested enhancement

```diff
 def test_fused_recurrent_transpose_state(
     ...
 ):
     ...
-    q, k, v, beta, g, h0_kv, h0_vk = map(lambda x: x.to(device), (q, k, v, beta, g, h0_kv, h0_vk))
+    q, k, v, beta, g, h0_kv, h0_vk = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, beta, g, h0_kv, h0_vk))
     ref, ref_ht = fused_recurrent_gated_delta_rule(...)
     tri, tri_ht = fused_recurrent_gated_delta_rule(...)
+
+    do = torch.randn_like(ref)
+    dht_vk = torch.randn(B, HV, D, D, dtype=torch.float32, device=device)
+    dht_kv = dht_vk.transpose(-1, -2).contiguous()
+
+    ((tri * do).sum() + (tri_ht * dht_vk).sum()).backward(retain_graph=True)
+    tri_dq, tri_dk, tri_dv = q.grad, k.grad, v.grad
+    q.grad = k.grad = v.grad = None
+
+    ((ref * do).sum() + (ref_ht * dht_kv).sum()).backward(retain_graph=True)
+    ref_dq, ref_dk, ref_dv = q.grad, k.grad, v.grad
+
+    assert_close('dq', ref_dq, tri_dq, 1e-4)
+    # ... additional gradient checks
+
     assert_close('o', ref, tri, 1e-4)
     assert_close('ht', ref_ht, tri_ht.transpose(-1, -2), 1e-4)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/test_gated_delta.py` around lines 236 - 296, The test test_fused_recurrent_transpose_state only checks forward outputs; add backward/gradient verification by computing a scalar loss (e.g., sum of outputs or dot with random upstream grads) for both calls to fused_recurrent_gated_delta_rule (the ref call with transpose_state_layout=False and the tri call with transpose_state_layout=True), call backward() to get gradients for inputs (q, k, v, beta, g and initial_state/h0_vk/h0_kv as applicable), and assert that corresponding gradients match (e.g., compare q.grad, k.grad, v.grad and initial_state.grad vs tri's grads after appropriate transpose) within the same tolerance used for forward (use assert_close on gradients); ensure you .clone().detach().requires_grad_(True) the inputs before calling the functions so gradients are tracked.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Line 313: The wrapper currently accepts transpose_state_layout: bool and then
only checks initial_state.shape[0], which lets a mismatched layout ([N,H,K,V] vs
expected [N,H,V,K]) slip through; update the wrapper to validate
initial_state.ndim and the ordering of dimensions when transpose_state_layout is
True by asserting (or raising a clear ValueError) that initial_state has four
dims and that its shape matches the expected [N, H, V, K] layout (or the
non-transposed [N, H, K, V] when False), referencing the transpose_state_layout
flag and initial_state parameter so misuse fails fast rather than silently
producing wrong outputs/gradients.
- Line 238: ChunkGatedDeltaRuleFunction.forward was extended to accept
transpose_state_layout (making 13 inputs) but backward still returns only 12
gradients; update ChunkGatedDeltaRuleFunction.backward to include an extra
gradient placeholder (e.g., None) corresponding to the transpose_state_layout
input so the returned tuple length matches the forward inputs. Locate the
backward implementation and add the additional None in the correct position in
the returned gradients tuple (matching the transpose_state_layout parameter) to
avoid PyTorch's incorrect-gradient-count error.
In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Line 171: When transpose_state_layout is True in
fused_recurrent_gated_delta_rule_fwd (and the analogous call around lines
181-185), validate that initial_state has the expected tail shape/layout ([N,
HV, V, K]) before proceeding; detect mismatches where initial_state is in [N,
HV, K, V] (or any other incompatible tail shape) and raise a clear ValueError
indicating the expected vs actual tail dims. Implement this check immediately
after reading transpose_state_layout and before any pointer math or kernel calls
so the kernel never silently misreads a mismatched cache; reference the
transpose_state_layout parameter and the initial_state tensor in your error
message for clarity.
In `@fla/ops/kda/chunk.py`:
- Line 163: chunk_kda currently only validates sequence count and dtype but not
the cache memory layout, so if transpose_state_layout is toggled it can silently
reinterpret memory in chunk_kda_fwd / chunk_kda_bwd; update chunk_kda() to
detect and reject mismatched cache layouts by checking the cache tensor's
shape/order against the transpose_state_layout flag (e.g., if
transpose_state_layout is False expect [N, S, H, K, V] layout or if True expect
[N, H, K, V] etc.), and raise a clear error when the actual layout doesn't match
the flag before calling chunk_kda_fwd / chunk_kda_bwd; reference the
transpose_state_layout parameter and the cache input used by chunk_kda(), and
add the guard early in chunk_kda() so downstream fwd/bwd kernels never
reinterpret memory incorrectly.
In `@fla/ops/kda/fused_recurrent.py`:
- Line 239: The code accepts transpose_state_layout but never validates
initial_state layout, so a stale [N, HV, K, V] buffer can be silently
reinterpreted as [N, HV, V, K]; add an explicit fast-fail validation in both the
inplace and regular paths (where transpose_state_layout is used) that checks
initial_state's dimensions/order against the expected final layout when
transpose_state_layout is True (expected axes [N, HV, V, K]) and raise a clear
ValueError if they mismatch; alternatively, if safe, perform an explicit
transpose/reorder of initial_state into the required layout before caching/using
it—apply this check/reorder for the code paths around the transpose_state_layout
flag and any functions that consume initial_state so the kernel never receives a
stale layout.
---
Outside diff comments:
In `@fla/ops/cp/chunk_delta_h.py`:
- Around line 953-992: The backward loads for gk use tl.load(gk + last_idx * H *
K + o_...) which ignores sequence offset (bos) and head index (i_h); update all
tl.load calls that compute b_gk_last1/2/3/4 to base the address on both bos and
i_h (e.g. replace last_idx * H * K with (bos + last_idx) * H * K + i_h * K) so
each load becomes tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_kN, ...),
ensuring the gk slice is correctly offset per sequence and head for b_gk_last1,
b_gk_last2, b_gk_last3, b_gk_last4.
---
Nitpick comments:
In `@tests/ops/test_gated_delta.py`:
- Around line 236-296: The test test_fused_recurrent_transpose_state only checks
forward outputs; add backward/gradient verification by computing a scalar loss
(e.g., sum of outputs or dot with random upstream grads) for both calls to
fused_recurrent_gated_delta_rule (the ref call with transpose_state_layout=False
and the tri call with transpose_state_layout=True), call backward() to get
gradients for inputs (q, k, v, beta, g and initial_state/h0_vk/h0_kv as
applicable), and assert that corresponding gradients match (e.g., compare
q.grad, k.grad, v.grad and initial_state.grad vs tri's grads after appropriate
transpose) within the same tolerance used for forward (use assert_close on
gradients); ensure you .clone().detach().requires_grad_(True) the inputs before
calling the functions so gradients are tracked.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: d42dc09b-bf3d-46e0-946c-920a22ec96c9
📒 Files selected for processing (16)
- fla/ops/common/backends/intracard.py
- fla/ops/common/chunk_delta_h.py
- fla/ops/common/chunk_o.py
- fla/ops/common/intracard_cp.py
- fla/ops/cp/chunk_delta_h.py
- fla/ops/gated_delta_rule/chunk.py
- fla/ops/gated_delta_rule/fused_recurrent.py
- fla/ops/gla/chunk.py
- fla/ops/kda/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_fwd.py
- fla/ops/kda/fused_recurrent.py
- tests/context_parallel/test_cp_gdn.py
- tests/context_parallel/test_cp_kda.py
- tests/ops/test_gated_delta.py
- tests/ops/test_kda.py
```python
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens_cpu: torch.LongTensor | None = None,
    cp_context: FLACPContext | None = None,
    transpose_state_layout: bool = False,
```
Validate initial_state for the selected layout.
This flag changes the expected cache shape to [N, H, V, K], but the wrapper still only checks initial_state.shape[0]. Reusing an old [N, H, K, V] cache with transpose_state_layout=True will be silently misread by the Triton kernels and return wrong outputs/gradients instead of failing fast.
🩹 Proposed fix

```diff
 if cu_seqlens is not None:
     if q.shape[0] != 1:
         raise ValueError(
             f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
             f"Please flatten variable-length inputs before processing.",
         )
     if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
         raise ValueError(
             f"The number of initial states is expected to be equal to the number of input sequences, "
             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
         )
+if initial_state is not None:
+    expected_tail = (
+        (q.shape[2], v.shape[-1], k.shape[-1])
+        if transpose_state_layout
+        else (q.shape[2], k.shape[-1], v.shape[-1])
+    )
+    if initial_state.shape[1:] != expected_tail:
+        raise ValueError(
+            f"`initial_state` must have shape [N, {expected_tail[0]}, {expected_tail[1]}, {expected_tail[2]}], "
+            f"got {tuple(initial_state.shape)}."
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_delta_rule/chunk.py` at line 313, The wrapper currently accepts
transpose_state_layout: bool and then only checks initial_state.shape[0], which
lets a mismatched layout ([N,H,K,V] vs expected [N,H,V,K]) slip through; update
the wrapper to validate initial_state.ndim and the ordering of dimensions when
transpose_state_layout is True by asserting (or raising a clear ValueError) that
initial_state has four dims and that its shape matches the expected [N, H, V, K]
layout (or the non-transposed [N, H, K, V] when False), referencing the
transpose_state_layout flag and initial_state parameter so misuse fails fast
rather than silently producing wrong outputs/gradients.
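The guard being requested reduces to a small shape check. The sketch below is a stand-alone illustration (the helper name and signature are hypothetical, not the fla API): derive the expected 4-D state shape from the layout flag and fail fast on a mismatch.

```python
# Minimal sketch of the requested validation: a [V, K]-layout cache is
# expected as (N, H, V, K); anything else raises instead of being misread.

def check_initial_state(shape, N, H, K, V, transpose_state_layout):
    expected = (N, H, V, K) if transpose_state_layout else (N, H, K, V)
    if len(shape) != 4 or tuple(shape) != expected:
        raise ValueError(
            f"`initial_state` must have shape {expected} when "
            f"`transpose_state_layout={transpose_state_layout}`, got {tuple(shape)}."
        )

# A matching [N, H, V, K] cache passes in transpose mode (K=64, V=128):
check_initial_state((2, 4, 128, 64), 2, 4, 64, 128, transpose_state_layout=True)
# A stale [N, H, K, V] cache is rejected instead of being silently misread:
try:
    check_initial_state((2, 4, 64, 128), 2, 4, 64, 128, transpose_state_layout=True)
except ValueError:
    pass  # fails fast, as intended
```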
```python
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.LongTensor | None = None,
    transpose_state_layout: bool = False,
```
Validate initial_state when transpose mode is enabled.
fused_recurrent_gated_delta_rule_fwd now flips the state layout to [N, HV, V, K], but it still accepts any initial_state tail shape. Because the kernel switches pointer math between [K, V] and [V, K], a mismatched cache will be silently misread rather than rejected.
Also applies to: 181-185
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_delta_rule/fused_recurrent.py` at line 171, When
transpose_state_layout is True in fused_recurrent_gated_delta_rule_fwd (and the
analogous call around lines 181-185), validate that initial_state has the
expected tail shape/layout ([N, HV, V, K]) before proceeding; detect mismatches
where initial_state is in [N, HV, K, V] (or any other incompatible tail shape)
and raise a clear ValueError indicating the expected vs actual tail dims.
Implement this check immediately after reading transpose_state_layout and before
any pointer math or kernel calls so the kernel never silently misreads a
mismatched cache; reference the transpose_state_layout parameter and the
initial_state tensor in your error message for clarity.
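The silent-misread failure mode described here can be shown with a toy example in plain Python (not the fla code): the same flat buffer reinterpreted as `[K, V]` versus `[V, K]` yields different values for most index pairs, with no error raised.

```python
# Why a layout mismatch is dangerous: reinterpreting the same raw memory
# under the wrong layout changes which value every (k, v) index pair sees.

K, V = 2, 3
flat = list(range(K * V))        # the raw cache memory: [0, 1, 2, 3, 4, 5]

def read_kv(k, v):
    """Read element (k, v) assuming the buffer is laid out as [K, V]."""
    return flat[k * V + v]

def read_vk(k, v):
    """Read element (k, v) assuming the same buffer is laid out as [V, K]."""
    return flat[v * K + k]

# Element (k=0, v=1): the [K, V] reader sees 1, the [V, K] reader sees 2.
assert read_kv(0, 1) == 1
assert read_vk(0, 1) == 2
```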
```python
    disable_recompute: bool = False,
    return_intermediate_states: bool = False,
    cp_context: FLACPContext = None,
    transpose_state_layout: bool = False,
```
Reject mismatched cache layouts at the KDA entrypoint.
transpose_state_layout is public now, but chunk_kda() still only validates the sequence count and dtype. Passing a [N, H, K, V] cache into transpose mode (or the inverse) will silently reinterpret memory in chunk_kda_fwd / chunk_kda_bwd and return incorrect activations and gradients.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/kda/chunk.py` at line 163, chunk_kda currently only validates
sequence count and dtype but not the cache memory layout, so if
transpose_state_layout is toggled it can silently reinterpret memory in
chunk_kda_fwd / chunk_kda_bwd; update chunk_kda() to detect and reject
mismatched cache layouts by checking the cache tensor's shape/order against the
transpose_state_layout flag (e.g., if transpose_state_layout is False expect [N,
S, H, K, V] layout or if True expect [N, H, K, V] etc.), and raise a clear error
when the actual layout doesn't match the flag before calling chunk_kda_fwd /
chunk_kda_bwd; reference the transpose_state_layout parameter and the cache
input used by chunk_kda(), and add the guard early in chunk_kda() so downstream
fwd/bwd kernels never reinterpret memory incorrectly.
```python
    use_gate_in_kernel: bool = False,
    lower_bound: float | None = None,
    out: torch.Tensor | None = None,
    transpose_state_layout: bool = False,
```
Fail fast on recurrent-state layout mismatches.
This path now allocates [N, HV, V, K] final states in transpose mode, but a provided initial_state is never checked against that layout. In both the inplace and regular paths, a stale [N, HV, K, V] cache will be silently reinterpreted by the kernel.
Also applies to: 259-262
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/kda/fused_recurrent.py` at line 239, The code accepts
transpose_state_layout but never validates initial_state layout, so a stale [N,
HV, K, V] buffer can be silently reinterpreted as [N, HV, V, K]; add an explicit
fast-fail validation in both the inplace and regular paths (where
transpose_state_layout is used) that checks initial_state's dimensions/order
against the expected final layout when transpose_state_layout is True (expected
axes [N, HV, V, K]) and raise a clear ValueError if they mismatch;
alternatively, if safe, perform an explicit transpose/reorder of initial_state
into the required layout before caching/using it—apply this check/reorder for
the code paths around the transpose_state_layout flag and any functions that
consume initial_state so the kernel never receives a stale layout.
Force-pushed from 23b84a2 to b574148
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
fla/ops/common/chunk_o.py (1)
Lines 138-170: ⚠️ Potential issue | 🔴 Critical

Complete the transpose rollout for the `dv` path.

`chunk_bwd_dqkwg` is now layout-aware, but `chunk_bwd_kernel_dv`/`chunk_bwd_dv` still hard-code `dh` as `(K, V)` at lines 379-381 and 542-593. With `transpose_state_layout=True`, `fla/ops/common/chunk_delta_h.py` now produces `dh` in `[V, K]`, so this branch will read the wrong strides and return incorrect `dv`.

Also applies to: 662-711

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@fla/ops/common/chunk_o.py` around lines 138 - 170, the dv-backward path still assumes dh has layout (K, V), causing wrong strides when transpose_state_layout is enabled; update the transpose rollout in the dv kernels so they honor the layout flag. Specifically, in chunk_bwd_kernel_dv and the caller chunk_bwd_dv, branch on the TRANSPOSE_STATE (or transpose_state_layout) constexpr and compute dh strides/shape and the indexing into dh as [K, V] when false and [V, K] when true (matching fla/ops/common/chunk_delta_h.py output); adjust any temporary views/loads and the accumulation into dv accordingly so all stride calculations and memory accesses use the correct dimension ordering under both layouts. Ensure the same fix is applied to the other occurrence noted so both code paths mirror the layout-aware logic used for dqkwg.

fla/ops/kda/chunk.py (1)

Lines 140-141: ⚠️ Potential issue | 🔴 Critical

Add missing gradient placeholder in backward return tuple.

`ChunkKDAFunction.forward()` takes 20 inputs after `ctx`, but `backward()` at lines 140-141 returns only 19 gradients. PyTorch will fail with an arity mismatch error. Add one more `None` gradient to the return tuple:

```diff
-        return (dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), dA, dbias, None, dh0,
-                None, None, None, None, None, None, None, None, None, None)
+        return (dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), dA, dbias, None, dh0,
+                None, None, None, None, None, None, None, None, None, None, None)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/chunk.py` around lines 140 - 141, the backward() in ChunkKDAFunction returns 19 gradients but forward() accepted 20 inputs after ctx; update ChunkKDAFunction.backward() to return one additional None so the returned tuple has 20 entries matching forward()'s inputs.
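The arity rule behind this fix is that `torch.autograd.Function.backward` must return exactly one gradient (or `None`) per `forward` input after `ctx`. The stand-in below demonstrates the mismatch with dummy names and `inspect`; it is not the real `ChunkKDAFunction` signature.

```python
# Stand-in sketch of the gradient-arity rule (no torch required). The forward
# signature below uses placeholder names purely to reproduce the 20-input count.
import inspect

def forward(ctx, q, k, v, g, beta, scale, initial_state, output_final_state,
            cu_seqlens, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10,
            transpose_state_layout):
    ...

def backward_grads():
    # 19 entries: one short of forward's 20 non-ctx inputs, which is exactly
    # the mismatch PyTorch would reject at the first backward pass.
    return ("dq", "dk", "dv", "dg", "db", "dA", "dbias", None, "dh0") + (None,) * 10

n_inputs = len(inspect.signature(forward).parameters) - 1   # drop ctx
fixed = backward_grads() + (None,)   # append a None for the new flag
assert n_inputs == 20 and len(backward_grads()) == 19 and len(fixed) == 20
```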
♻️ Duplicate comments (5)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
Lines 171-187: ⚠️ Potential issue | 🟠 Major

Validate `initial_state` for the selected layout.

This path now allocates `[N, HV, V, K]` in transpose mode, but a provided `initial_state` is still never checked. A stale cache, or a non-contiguous transpose view, will be silently reinterpreted by the kernel.

🩹 Proposed fix

```diff
     o = torch.empty_like(v)
+    if initial_state is not None:
+        expected_shape = (
+            (N, HV, V, K)
+            if transpose_state_layout
+            else (N, HV, K, V)
+        )
+        if (
+            initial_state.ndim != 4
+            or tuple(initial_state.shape) != expected_shape
+            or not initial_state.is_contiguous()
+        ):
+            raise ValueError(
+                f"`initial_state` must be contiguous and have shape {expected_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`, got shape "
+                f"{tuple(initial_state.shape)} with strides {tuple(initial_state.stride())}."
+            )
     if output_final_state:
         if transpose_state_layout:
             final_state = q.new_empty(N, HV, V, K, dtype=torch.float32)
         else:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 171 - 187, the code allocates final_state with different memory layouts when transpose_state_layout is True (q.new_empty(N, HV, V, K)) vs False (q.new_empty(N, HV, K, V)) but never validates a provided initial_state; update the logic that handles initial_state/final_state to check that when initial_state is not None its shape and memory layout match the expected layout for the chosen transpose_state_layout (exact shape [N, HV, V, K] in transpose mode or [N, HV, K, V] otherwise), raise a clear error if mismatched, and ensure any view/transpose is made contiguous (or copied to the expected layout) before passing to downstream kernels.

fla/ops/gated_delta_rule/chunk.py (2)
291-296:⚠️ Potential issue | 🔴 CriticalAdd the missing backward slot for
transpose_state_layout.Line 296 still returns 12 gradients for a
forward()with 13 inputs afterctx, so PyTorch will fail the first backward pass with a gradient-arity error.Verification script
#!/bin/bash python - <<'PY' import ast from pathlib import Path path = Path("fla/ops/gated_delta_rule/chunk.py") tree = ast.parse(path.read_text()) cls = next(n for n in tree.body if isinstance(n, ast.ClassDef) and n.name == "ChunkGatedDeltaRuleFunction") fwd = next(n for n in cls.body if isinstance(n, ast.FunctionDef) and n.name == "forward") bwd = next(n for n in cls.body if isinstance(n, ast.FunctionDef) and n.name == "backward") ret = next(n for n in ast.walk(bwd) if isinstance(n, ast.Return)) print("forward inputs:", len(fwd.args.args) - 1) print("backward outputs:", len(ret.value.elts)) # Expect these counts to match. PY🩹 Proposed fix
```diff
-        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None
+        return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None, None, None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk.py` around lines 291 - 296, The backward() return tuple is missing a gradient slot for the forward() input transpose_state_layout, causing a mismatch in arity; update ChunkGatedDeltaRuleFunction.backward to include the missing gradient position (most likely a None if no gradient is required) corresponding to transpose_state_layout so the number of returned gradients equals the number of forward inputs after ctx; locate the return in backward() and insert the appropriate None (or computed grad) at the position matching transpose_state_layout.
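As an aside, the arity contract that this fix restores is easy to reproduce with a toy `torch.autograd.Function` (hypothetical names, not FLA code): `backward` must return exactly one value per `forward` input after `ctx`, with `None` filling the slots of non-differentiable arguments such as a boolean layout flag.

```python
import torch

class ScaleWithFlag(torch.autograd.Function):
    """Toy op whose forward takes (x, flag) after ctx, so backward must return 2 values."""

    @staticmethod
    def forward(ctx, x, flag):
        ctx.scale = 2.0 if flag else 1.0
        return x * ctx.scale

    @staticmethod
    def backward(ctx, grad_out):
        # One slot per forward input: a real gradient for `x`,
        # None for the non-differentiable `flag`. Dropping the trailing
        # None would fail the first backward with a gradient-count error.
        return grad_out * ctx.scale, None

x = torch.ones(3, requires_grad=True)
ScaleWithFlag.apply(x, True).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```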
313-418: ⚠️ Potential issue | 🟠 Major

Validate the state tensor contract before dispatch.

`transpose_state_layout` flips the expected cache layout, but this wrapper still only checks sequence count. A stale cache — or a non-contiguous `initial_state.transpose(-1, -2)` view — will be silently misread by the Triton kernels.

🩹 Proposed fix
```diff
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                 f"Please flatten variable-length inputs before processing.",
             )
         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
             raise ValueError(
                 f"The number of initial states is expected to be equal to the number of input sequences, "
                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
             )
+    if initial_state is not None:
+        expected_n = len(cu_seqlens) - 1 if cu_seqlens is not None else q.shape[0]
+        expected_shape = (
+            (expected_n, q.shape[2], v.shape[-1], k.shape[-1])
+            if transpose_state_layout
+            else (expected_n, q.shape[2], k.shape[-1], v.shape[-1])
+        )
+        if (
+            initial_state.ndim != 4
+            or tuple(initial_state.shape) != expected_shape
+            or not initial_state.is_contiguous()
+        ):
+            raise ValueError(
+                f"`initial_state` must be contiguous and have shape {expected_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`, got shape "
+                f"{tuple(initial_state.shape)} with strides {tuple(initial_state.stride())}."
+            )
     if scale is None:
         scale = k.shape[-1] ** -0.5
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk.py` around lines 313-418: the wrapper must validate the `initial_state` tensor's layout and shape before calling `ChunkGatedDeltaRuleFunction.apply`: check that `initial_state` is a 4-D tensor with `initial_state.shape[0] == (len(cu_seqlens) - 1 if cu_seqlens is given else q.shape[0])`, and that its last two dims match K and V taking `transpose_state_layout` into account (i.e., expect `[..., K, V]` when `transpose_state_layout=False` and `[..., V, K]` when True); also ensure the tensor is contiguous in the memory layout the Triton kernel expects (use `.is_contiguous()` or verify strides) and, if it isn't, make a contiguous copy or explicitly transpose + contiguous so the dispatched kernel never receives a non-contiguous/transposed view. Apply these checks/normalization right before calling `ChunkGatedDeltaRuleFunction.apply`.

fla/ops/kda/chunk.py (1)
163-338: ⚠️ Potential issue | 🟠 Major

Validate the state tensor contract before calling `chunk_kda_fwd`.

`transpose_state_layout` changes the cache layout, but `chunk_kda()` still only checks dtype and sequence count. A stale cache — or a non-contiguous `initial_state.transpose(-1, -2)` view — will be silently reinterpreted downstream.

🩹 Proposed fix
```diff
     if cu_seqlens is not None:
         if q.shape[0] != 1:
             raise ValueError(
                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                 f"Please flatten variable-length inputs before processing.",
@@
             )
     if initial_state is not None:
         assert initial_state.dtype == torch.float32, "initial_state must be in float32."
+        expected_n = len(cu_seqlens) - 1 if cu_seqlens is not None else q.shape[0]
+        expected_shape = (
+            (expected_n, q.shape[2], v.shape[-1], k.shape[-1])
+            if transpose_state_layout
+            else (expected_n, q.shape[2], k.shape[-1], v.shape[-1])
+        )
+        if (
+            initial_state.ndim != 4
+            or tuple(initial_state.shape) != expected_shape
+            or not initial_state.is_contiguous()
+        ):
+            raise ValueError(
+                f"`initial_state` must be contiguous and have shape {expected_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`, got shape "
+                f"{tuple(initial_state.shape)} with strides {tuple(initial_state.stride())}."
+            )
     A_log, dt_bias = None, None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/chunk.py` around lines 163-338: `chunk_kda` currently only checks `initial_state` dtype and count but ignores layout changes caused by `transpose_state_layout`, risking silent misinterpretation. Before calling `ChunkKDAFunction.apply`, validate `initial_state` more strictly: confirm its dtype is `torch.float32`, confirm its shape matches the expected layout (`[N, H, K, V]` when `transpose_state_layout=False` or `[N, H, V, K]` when True), confirm the first dimension equals `len(cu_seqlens) - 1` when `cu_seqlens` is provided, and ensure the tensor is contiguous (or explicitly require/clone to contiguous) so a transposed view cannot be passed through; perform these checks inside `chunk_kda` (near the existing `initial_state` assertions) and raise a descriptive `ValueError`/assertion mentioning `initial_state` and `transpose_state_layout` if the contract is violated.

fla/ops/kda/fused_recurrent.py (1)
239-262: ⚠️ Potential issue | 🟠 Major

Reject recurrent caches with the wrong tail layout.

This helper now produces `[*, HV, V, K]` caches in transpose mode, but a provided `initial_state` is still never validated. A stale cache — or a non-contiguous transpose view — will be silently reinterpreted in both the inplace and regular paths.

🩹 Proposed fix
```diff
     if out is None:
         out = torch.zeros_like(v)
     else:
         assert out.shape == v.shape
+    if initial_state is not None:
+        expected_tail = (
+            (HV, V, K)
+            if transpose_state_layout
+            else (HV, K, V)
+        )
+        if (
+            initial_state.ndim != 4
+            or tuple(initial_state.shape[1:]) != expected_tail
+            or not initial_state.is_contiguous()
+        ):
+            raise ValueError(
+                f"`initial_state` must be contiguous and end with {expected_tail} when "
+                f"`transpose_state_layout={transpose_state_layout}`, got shape "
+                f"{tuple(initial_state.shape)} with strides {tuple(initial_state.stride())}."
+            )
     if inplace_final_state:
         assert initial_state is not None
         final_state = initial_state
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/fused_recurrent.py` around lines 239 - 262, In fused_recurrent.py, validate any provided initial_state against the expected layout before reusing it: compute N and expected shape based on transpose_state_layout (expected = (N, HV, V, K) when transpose_state_layout is True, otherwise (N, HV, K, V)), then assert initial_state is not None implies initial_state.shape == expected, initial_state.device == q.device and initial_state.dtype == q.dtype (or float32 if q is cast), and that initial_state.is_contiguous() (or use .contiguous() only after cloning) to avoid silently reinterpreting a non-contiguous transpose view; raise a clear ValueError/AssertionError if the checks fail, and in the non-inplace path consider cloning/copying a validated initial_state into final_state instead of reinterpreting it.
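The "non-contiguous transpose view" failure mode called out above is worth spelling out: `transpose(-1, -2)` only swaps strides, so stride-unaware pointer math over the buffer still walks the old `[K, V]` order. A minimal illustration (shapes here are arbitrary, not the FLA defaults):

```python
import torch

h = torch.randn(2, 4, 8, 16)      # [N, HV, K, V], contiguous
h_t = h.transpose(-1, -2)         # [N, HV, V, K] *view*: same storage, swapped strides

print(h_t.shape, h_t.is_contiguous())   # torch.Size([2, 4, 16, 8]) False
# The underlying buffer is untouched, so a kernel indexing raw memory
# would still read the data in [K, V] order:
assert h_t.data_ptr() == h.data_ptr()

# .contiguous() actually rewrites memory into the [V, K] layout the kernel expects.
h_vk = h_t.contiguous()
assert h_vk.is_contiguous()
assert h_vk.data_ptr() != h.data_ptr()
```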
🧹 Nitpick comments (1)
tests/context_parallel/test_cp_gdn.py (1)
93-103: Use a rectangular K/V case in the transpose-state suite.

The harness still takes a single `D`, so every transpose-state run here has `K == V`. That masks exactly the `[K, V]` vs `[V, K]` mix-ups this flag is meant to catch; please split the helper into separate `K`/`V` sizes and add at least one `K != V` transpose case.

Also applies to: 275-295, 390-417
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/context_parallel/test_cp_gdn.py` around lines 93 - 103, The helper run_cp_gdn_test_worker currently takes a single D so K==V for all transpose-state tests; change its signature to accept separate K and V sizes (e.g., add parameters K: int, V: int) and update its internal construction of key/value tensors and any shapes that used D to use K and V respectively; update all call sites (including the other two helper occurrences noted) to pass distinct K and V where you want a rectangular case and add at least one test invocation with K != V to the transpose-state suite so the [K,V] vs [V,K] mix-up is exercised.
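The motivation behind the rectangular case is that layout bugs are shape-silent when K == V: a state stored as `[V, K]` passes every shape check written for `[K, V]` and only corrupts values. A hypothetical shape guard (`read_state_kv` is illustrative, not FLA code) makes this concrete:

```python
import torch

def read_state_kv(state, K, V):
    """Hypothetical consumer that assumes a [K, V] state layout."""
    assert state.shape[-2:] == (K, V), (
        f"expected [K={K}, V={V}], got {tuple(state.shape[-2:])}"
    )
    return state

# Square case: the wrong layout sails through every shape check.
K = V = 64
square = torch.randn(K, V)
read_state_kv(square.transpose(-1, -2), K, V)  # no error, silently wrong values

# Rectangular case: the same mistake fails loudly.
K, V = 64, 128
rect = torch.randn(K, V)
try:
    read_state_kv(rect.transpose(-1, -2), K, V)
except AssertionError as e:
    print("caught:", e)
```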
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/ops/test_gated_delta.py`:
- Around line 190-192: The test uses all-zero initial states h0_kv and h0_vk
which masks transpose_state_layout bugs; change the initialization of h0_kv (and
its transposed h0_vk) to a non-zero, deterministic tensor (e.g., seeded
torch.randn or a pattern like arange/ones scaled by indices) so that h0_kv != 0
and h0_vk = h0_kv.transpose(-1,-2).contiguous() and still pass both through the
existing map that sets .to(device).requires_grad_(True); update the lines
creating h0_kv/h0_vk and their subsequent mapping to ensure the forward path
detects incorrect transposed loads.
In `@tests/ops/test_kda.py`:
- Around line 163-171: This transpose-only test is missing the Intel Alchemist
skip that test_fused_recurrent uses; add the same guard so the test calls
pytest.skip when running on Alchemist with D > 128. Locate the transpose test in
tests/ops/test_kda.py (same scope as the q,k,v,g,beta,h0_kv,h0_vk setup), import
or reuse the existing is_alchemist() helper and add a conditional like `if
is_alchemist() and D > 128: pytest.skip(...)` before creating tensors so the
test mirrors test_fused_recurrent's behavior.
---
Outside diff comments:
In `@fla/ops/common/chunk_o.py`:
- Around line 138-170: The dv-backward path still assumes dh has layout (K, V)
causing wrong strides when transpose_state_layout is enabled; update the
transpose rollout in the dv kernels so they honor the layout flag. Specifically,
in chunk_bwd_kernel_dv and the caller chunk_bwd_dv, branch on the
TRANSPOSE_STATE (or transpose_state_layout) constexpr and compute dh
strides/shape and the indexing into dh as [K, V] when false and [V, K] when true
(matching fla/ops/common/chunk_delta_h.py output); adjust any temporary
views/loads and the accumulation into dv accordingly so all stride calculations
and memory accesses use the correct dimension ordering under both layouts.
Ensure the same fix is applied to the other occurrence noted (the later
chunk_bwd_kernel/dv block) so both code paths mirror the layout-aware logic used
for dqkwg.
In `@fla/ops/kda/chunk.py`:
- Around line 140-141: The backward() in ChunkKDAFunction returns 19 gradients
but forward() accepted 20 inputs after ctx; update ChunkKDAFunction.backward()
return tuple (the line returning (dq.to(q), dk.to(k), dv.to(v), dg.to(g),
db.to(beta), dA, dbias, None, dh0, ...)) to include one additional None so the
returned tuple has 20 entries to match forward()'s inputs (i.e., append another
None in the return tuple).
---
Duplicate comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Around line 291-296: The backward() return tuple is missing a gradient slot
for the forward() input transpose_state_layout, causing a mismatch in arity;
update ChunkGatedDeltaRuleFunction.backward to include the missing gradient
position (most likely a None if no gradient is required) corresponding to
transpose_state_layout so the number of returned gradients equals the number of
forward inputs after ctx; locate the return in backward() and insert the
appropriate None (or computed grad) at the position matching
transpose_state_layout.
- Around line 313-418: The wrapper must validate the initial_state tensor's
layout and shape before calling ChunkGatedDeltaRuleFunction.apply: check
initial_state is a 4-D tensor with initial_state.shape[0] == (len(cu_seqlens)-1
if cu_seqlens is given else q.shape[0]), and that its last two dims match K and
V taking transpose_state_layout into account (i.e., expect [..., K, V] when
transpose_state_layout=False and [..., V, K] when True); also ensure the tensor
is contiguous in the memory layout the Triton kernel expects (use
.is_contiguous() or verify strides) and if it isn’t, make a contiguous copy or
explicitly transpose+contiguous so the dispatched kernel never receives a
non-contiguous/transposed view. Apply these checks/normalization right before
calling ChunkGatedDeltaRuleFunction.apply.
In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 171-187: The code allocates final_state with different memory
layouts when transpose_state_layout is True (q.new_empty(N, HV, V, K)) vs False
(q.new_empty(N, HV, K, V)) but never validates a provided initial_state; update
the logic in the function that handles initial_state/ final_state (look for
initial_state, transpose_state_layout, output_final_state, final_state) to check
that when initial_state is not None its shape and memory layout match the
expected layout for the chosen transpose_state_layout (e.g., exact shape [N, HV,
V, K] for transpose_mode or [N, HV, K, V] otherwise), raise a clear error if
mismatched, and ensure any view/transpose is made contiguous (or copy to the
expected layout) before passing to downstream kernels.
In `@fla/ops/kda/chunk.py`:
- Around line 163-338: chunk_kda currently only checks initial_state dtype and
count but ignores layout changes caused by transpose_state_layout, risking
silent misinterpretation; before calling ChunkKDAFunction.apply, validate
initial_state more strictly: confirm dtype is torch.float32, confirm
initial_state.shape matches the expected layout ([N, H, K, V] when
transpose_state_layout=False or [N, H, V, K] when True), confirm the first
dimension equals len(cu_seqlens)-1 when cu_seqlens is provided, and ensure the
tensor is contiguous (or explicitly require/clone to contiguous) so a transposed
view cannot be passed through; perform these checks inside chunk_kda (near the
existing initial_state assertions) and raise descriptive ValueError/assertion
mentioning initial_state and transpose_state_layout if the contract is violated.
In `@fla/ops/kda/fused_recurrent.py`:
- Around line 239-262: In fused_recurrent.py, validate any provided
initial_state against the expected layout before reusing it: compute N and
expected shape based on transpose_state_layout (expected = (N, HV, V, K) when
transpose_state_layout is True, otherwise (N, HV, K, V)), then assert
initial_state is not None implies initial_state.shape == expected,
initial_state.device == q.device and initial_state.dtype == q.dtype (or float32
if q is cast), and that initial_state.is_contiguous() (or use .contiguous() only
after cloning) to avoid silently reinterpreting a non-contiguous transpose view;
raise a clear ValueError/AssertionError if the checks fail, and in the
non-inplace path consider cloning/copying a validated initial_state into
final_state instead of reinterpreting it.
---
Nitpick comments:
In `@tests/context_parallel/test_cp_gdn.py`:
- Around line 93-103: The helper run_cp_gdn_test_worker currently takes a single
D so K==V for all transpose-state tests; change its signature to accept separate
K and V sizes (e.g., add parameters K: int, V: int) and update its internal
construction of key/value tensors and any shapes that used D to use K and V
respectively; update all call sites (including the other two helper occurrences
noted) to pass distinct K and V where you want a rectangular case and add at
least one test invocation with K != V to the transpose-state suite so the [K,V]
vs [V,K] mix-up is exercised.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: a9fac52e-3dca-4b1e-a740-a94e23338e3f
📒 Files selected for processing (16)
- fla/ops/common/backends/intracard.py
- fla/ops/common/chunk_delta_h.py
- fla/ops/common/chunk_o.py
- fla/ops/common/intracard_cp.py
- fla/ops/cp/chunk_delta_h.py
- fla/ops/gated_delta_rule/chunk.py
- fla/ops/gated_delta_rule/fused_recurrent.py
- fla/ops/gla/chunk.py
- fla/ops/kda/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_fwd.py
- fla/ops/kda/fused_recurrent.py
- tests/context_parallel/test_cp_gdn.py
- tests/context_parallel/test_cp_kda.py
- tests/ops/test_gated_delta.py
- tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/context_parallel/test_cp_kda.py
```python
    torch.manual_seed(42)
    q = torch.rand(B, T, H, D, dtype=dtype)
    k = torch.rand(B, T, H, D, dtype=dtype)
    v = torch.rand(B, T, H, D, dtype=dtype)
    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=torch.float)) / gate_logit_normalizer
    beta = torch.randn(B, T, H, dtype=dtype).sigmoid()
    h0_kv = torch.randn(B, H, D, D, dtype=torch.float32)
    h0_vk = h0_kv.transpose(-1, -2).contiguous()
    q, k, v, g, beta, h0_kv, h0_vk = map(lambda x: x.to(device), (q, k, v, g, beta, h0_kv, h0_vk))
```
Mirror the existing Alchemist guard in this transpose test.
test_fused_recurrent already skips D > 128 on Intel Alchemist, but this new variant drops that guard. That makes the transpose-only test fail on a known unsupported backend instead of exercising the new layout path.
Suggested fix
```diff
 def test_fused_recurrent_transpose_state(
     B: int,
     T: int,
     H: int,
     D: int,
     scale: float,
     gate_logit_normalizer: float,
     dtype: torch.dtype,
 ):
     torch.manual_seed(42)
+    if IS_INTEL_ALCHEMIST and D > 128:
+        pytest.skip(reason="fused_recurrent_kda is not supported on alchemist for D>128")
     q = torch.rand(B, T, H, D, dtype=dtype)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    torch.manual_seed(42)
    if IS_INTEL_ALCHEMIST and D > 128:
        pytest.skip(reason="fused_recurrent_kda is not supported on alchemist for D>128")
    q = torch.rand(B, T, H, D, dtype=dtype)
    k = torch.rand(B, T, H, D, dtype=dtype)
    v = torch.rand(B, T, H, D, dtype=dtype)
    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=torch.float)) / gate_logit_normalizer
    beta = torch.randn(B, T, H, dtype=dtype).sigmoid()
    h0_kv = torch.randn(B, H, D, D, dtype=torch.float32)
    h0_vk = h0_kv.transpose(-1, -2).contiguous()
    q, k, v, g, beta, h0_kv, h0_vk = map(lambda x: x.to(device), (q, k, v, g, beta, h0_kv, h0_vk))
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/ops/test_kda.py` around lines 163 - 171, This transpose-only test is
missing the Intel Alchemist skip that test_fused_recurrent uses; add the same
guard so the test calls pytest.skip when running on Alchemist with D > 128.
Locate the transpose test in tests/ops/test_kda.py (same scope as the
q,k,v,g,beta,h0_kv,h0_vk setup), import or reuse the existing is_alchemist()
helper and add a conditional like `if is_alchemist() and D > 128:
pytest.skip(...)` before creating tensors so the test mirrors
test_fused_recurrent's behavior.
Force-pushed from b574148 to d4b4909 (Compare)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
fla/ops/common/chunk_delta_h.py (1)
727-748: ⚠️ Potential issue | 🟠 Major

Use `chunk_size` for `BT` here.

Line 734 still hardcodes `BT = 64`. If the forward pass used any other `chunk_size`, this backward wrapper will derive a different `NT` and state layout than the tensors it is consuming.

🐛 Proposed fix
```diff
-    BT = 64
+    BT = chunk_size
```

If 64 is the only supported value, assert that explicitly and drop the parameter instead.
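Why the hardcoded value matters: the number of chunks `NT = ceil(T / BT)` feeds the `dh` tensor shape, so any forward/backward disagreement on the chunk size shifts every per-chunk state. A sketch of the mismatch (plain arithmetic, not the FLA helpers):

```python
from math import ceil

T = 200                # sequence length
chunk_size = 32        # what the forward pass actually used
BT_hardcoded = 64      # what the backward wrapper assumes

nt_forward = ceil(T / chunk_size)      # 7 chunk states written
nt_backward = ceil(T / BT_hardcoded)   # 4 chunk states expected
print(nt_forward, nt_backward)         # 7 4
# Any BT disagreement changes NT, so a dh buffer laid out over NT chunks
# no longer lines up with the states the forward pass produced.
assert nt_forward != nt_backward
```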
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/common/chunk_delta_h.py` around lines 727-748: the code currently hardcodes `BT = 64`, which can mismatch the provided `chunk_size` and lead to incorrect `NT` and state layouts; replace the hardcoded `BT` with `BT = chunk_size` (or, if only 64 is supported, assert `chunk_size == 64` and remove the parameter) so `NT` and chunk offsets are computed consistently; update any dependent logic that uses `BT` (e.g., `prepare_chunk_offsets(cu_seqlens, BT)`, the `NT` computation, the `dh` shape created from `NT`, and `dh0`) to rely on the corrected `BT` value and keep the existing `transpose_state_layout`, `dh`, `dh0`, `cu_seqlens`, `chunk_indices` handling unchanged.

fla/ops/kda/fused_recurrent.py (1)
144-167: ⚠️ Potential issue | 🔴 Critical

Mask `g` loads on padded K lanes.

`BK` is rounded up with `triton.next_power_of_2(K)`, but `p_g` is loaded without `mask_k` at line 148. For non-power-of-two head sizes, this reads past the end of each `g` row, and those unmasked padded lanes flow directly into `exp(b_gk)` and corrupt the hidden state `b_h` via the multiplication at lines 157–159.

Proposed fix
```diff
-            b_g = tl.load(p_g, eviction_policy='evict_last').to(tl.float32)
+            b_g = tl.load(p_g, mask=mask_k, other=0, eviction_policy='evict_last').to(tl.float32)
```
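The padding hazard can be simulated outside Triton: `BK` is the next power of two at or above `K`, so lanes `[K, BK)` hold garbage unless the load is masked with `other=0`. Here the kernel's lane mask is emulated with torch (the `1e4` filler stands in for arbitrary out-of-bounds memory):

```python
import torch

K = 48
BK = 1 << (K - 1).bit_length()   # triton.next_power_of_2(K) -> 64

# Emulate an unmasked BK-wide load: lanes past K read whatever sits in memory.
g_row = torch.randn(K)
padded = torch.empty(BK)
padded[:K] = g_row
padded[K:] = 1e4                 # stand-in for garbage past the end of the row

mask_k = torch.arange(BK) < K
unmasked_decay = padded.exp()                          # exp(1e4) overflows to inf
masked_decay = padded.masked_fill(~mask_k, 0.0).exp()  # padded lanes become exp(0) = 1

print(bool(torch.isinf(unmasked_decay).any()))  # True: garbage would corrupt the state
print(bool(torch.isinf(masked_decay).any()))    # False
```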
♻️ Duplicate comments (4)
tests/ops/test_gated_delta.py (1)
190-192: ⚠️ Potential issue | 🟡 Minor

Use non-zero initial state to properly test transpose layout.

`h0_kv` is initialized with `torch.zeros`, which masks layout bugs since both correct and incorrect transposed loads produce the same result (zeros). Use `torch.randn` like `test_fused_recurrent_transpose_state` does at line 266.

Suggested fix
```diff
-    h0_kv = torch.zeros(B, H, D, D, dtype=torch.float32)
+    h0_kv = torch.randn(B, H, D, D, dtype=torch.float32)
     h0_vk = h0_kv.transpose(-1, -2).contiguous()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/test_gated_delta.py` around lines 190-192: the initial state `h0_kv` is initialized with zeros, which hides transpose/layout bugs (an all-zero state equals its own transpose, so a misread is invisible); replace its initialization with a non-zero random tensor, i.e. `h0_kv = torch.randn(B, H, D, D, dtype=torch.float32)`, then compute `h0_vk = h0_kv.transpose(-1, -2).contiguous()` and keep the existing `map(... .to(device).requires_grad_(True))` call for `(q, k, v, beta, g, h0_kv, h0_vk)` so gradients and device placement remain the same.

tests/ops/test_kda.py (1)
163-171: ⚠️ Potential issue | 🟡 Minor

Add Intel Alchemist skip guard for consistency.

`test_fused_recurrent` at lines 102-103 skips when `IS_INTEL_ALCHEMIST and D > 128`, but this transpose test is missing that guard. This could cause failures on Alchemist hardware instead of exercising the transpose layout path.

Suggested fix
```diff
 def test_fused_recurrent_transpose_state(
     B: int,
     T: int,
     H: int,
     D: int,
     scale: float,
     gate_logit_normalizer: float,
     dtype: torch.dtype,
 ):
     torch.manual_seed(42)
+    if IS_INTEL_ALCHEMIST and D > 128:
+        pytest.skip(reason="fused_recurrent_kda is not supported on alchemist for D>128")
     q = torch.rand(B, T, H, D, dtype=dtype)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/test_kda.py` around lines 163-171: add the same Intel Alchemist skip guard used in `test_fused_recurrent` to this transpose test: check `IS_INTEL_ALCHEMIST and D > 128` at the start of the test (before creating `q, k, v, g, beta, h0_kv, h0_vk`) and call `pytest.skip(...)` when true so the transpose-layout path is not executed on Alchemist hardware; mirror the skip condition and message from `test_fused_recurrent`.

fla/ops/gated_delta_rule/fused_recurrent.py (1)
171-187: ⚠️ Potential issue | 🟠 Major

At least validate `initial_state` against the selected state shape.

`transpose_state_layout` now changes the state contract here, but this wrapper still accepts any `initial_state`. The public entry point only validates the sequence count, so incompatible `HV` or tail dims can still reach the kernel's hard-coded pointer math and be silently misread.

🛡️ Proposed fix
```diff
     B, T, H, K, V = *k.shape, v.shape[-1]
     HV = v.shape[2]
     N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    if initial_state is not None:
+        expected_state_shape = (N, HV, V, K) if transpose_state_layout else (N, HV, K, V)
+        if tuple(initial_state.shape) != expected_state_shape:
+            raise ValueError(
+                f"`initial_state` must have shape {expected_state_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`; got {tuple(initial_state.shape)}."
+            )
     BK = triton.next_power_of_2(K)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 171-187: validate any provided `initial_state` against the expected layout determined by `transpose_state_layout` before using it: check that when `initial_state` is not None its shape equals `(N, HV, V, K)` if `transpose_state_layout` is True, otherwise `(N, HV, K, V)`, and raise a clear `ValueError` on mismatch (include expected vs actual shape); also validate its dtype/device are compatible with `q`/`v` and that the leading dimension `N` matches `len(cu_seqlens) - 1` when `cu_seqlens` is provided. Ensure this validation occurs before any pointer/stride math or kernel dispatch that assumes the state layout.

fla/ops/kda/fused_recurrent.py (1)
239-262: ⚠️ Potential issue | 🟠 Major

At least validate `initial_state` against the selected state shape.

This wrapper now flips the state contract between `[N, HV, K, V]` and `[N, HV, V, K]`, but it still forwards any `initial_state` straight into the kernel and, in the inplace path, reuses it as `final_state`. The varlen entry point only checks `shape[0]`, so incompatible `HV` or tail dims can still be silently misread.

🛡️ Proposed fix
```diff
     B, T, H, K, V = *k.shape, v.shape[-1]
     HV = v.shape[2]
     N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    if initial_state is not None:
+        expected_state_shape = (N, HV, V, K) if transpose_state_layout else (N, HV, K, V)
+        if tuple(initial_state.shape) != expected_state_shape:
+            raise ValueError(
+                f"`initial_state` must have shape {expected_state_shape} when "
+                f"`transpose_state_layout={transpose_state_layout}`; got {tuple(initial_state.shape)}."
+            )
     BK = triton.next_power_of_2(K)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/fused_recurrent.py` around lines 239 - 262, The wrapper flips state layout between [N, HV, K, V] and [N, HV, V, K] based on transpose_state_layout but never validates initial_state against the selected layout; add explicit shape validation where initial_state is accepted or reused (symbols: initial_state, final_state, transpose_state_layout, inplace_final_state, output_final_state) — check that initial_state.ndim and each dimension (N, HV, K, V or V,K order) match the expected shape and dtype, and raise a clear ValueError/Assertion if they don’t; also when inplace_final_state is true, ensure final_state (== initial_state) exactly matches the expected layout before using it, and perform the same validation for any varlen entry-point that previously only checked shape[0] so HV and tail dimensions cannot be silently misread.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 114-126: The loads for p_gk and p_gv are not masked, so padded
lanes read garbage and corrupt b_h via exp(...) multiplications; update the
tl.load calls that produce b_gk and b_gv to use the same padding mask used for
q/k/v/beta loads (i.e., pass the mask and mask_fill value) so out-of-range lanes
are zeroed before computing exp and multiplying into b_h, keeping the existing
TRANSPOSE_STATE branching and use of USE_GK/USE_GV and symbols p_gk, p_gv, b_gk,
b_gv, b_h unchanged.
---
Outside diff comments:
In `@fla/ops/common/chunk_delta_h.py`:
- Around line 727-748: The code currently hardcodes BT = 64 which can mismatch
the provided chunk_size and lead to incorrect NT and state layouts; replace the
hardcoded BT with BT = chunk_size (or if only 64 is supported, assert chunk_size
== 64 and remove the parameter) so NT and chunk_offsets are computed
consistently; update any dependent logic that uses BT (e.g.,
prepare_chunk_offsets(cu_seqlens, BT), NT computation, dh shape creation using
NT, and dh0) to rely on the corrected BT value and keep the existing
transpose_state_layout, dh, dh0, cu_seqlens, chunk_indices handling unchanged.
---
Duplicate comments:
In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 171-187: The wrapper must validate any provided initial_state
against the expected layout determined by transpose_state_layout before using
it: in the function (fused_recurrent / where transpose_state_layout, HV, K, V, N
are computed) check that when initial_state is not None its shape equals (N, HV,
V, K) if transpose_state_layout is True, otherwise equals (N, HV, K, V), and
raise a clear ValueError if mismatched (include expected vs actual shape); also
validate its dtype/device are compatible with q/v and that the leading dimension
N matches len(cu_seqlens)-1 when cu_seqlens is provided. Ensure this validation
occurs before any pointer/stride math or kernel dispatch that assumes the state
layout.
In `@fla/ops/kda/fused_recurrent.py`:
- Around line 239-262: The wrapper flips state layout between [N, HV, K, V] and
[N, HV, V, K] based on transpose_state_layout but never validates initial_state
against the selected layout; add explicit shape validation where initial_state
is accepted or reused (symbols: initial_state, final_state,
transpose_state_layout, inplace_final_state, output_final_state) — check that
initial_state.ndim and each dimension (N, HV, K, V or V,K order) match the
expected shape and dtype, and raise a clear ValueError/Assertion if they don’t;
also when inplace_final_state is true, ensure final_state (== initial_state)
exactly matches the expected layout before using it, and perform the same
validation for any varlen entry-point that previously only checked shape[0] so
HV and tail dimensions cannot be silently misread.
In `@tests/ops/test_gated_delta.py`:
- Around line 190-192: The initial state h0_kv is incorrectly initialized with
zeros which hides transpose/layout bugs; replace its initialization with a
non-zero random tensor (use torch.randn with same shape and dtype) so h0_kv =
torch.randn(B, H, D, D, dtype=torch.float32), then compute h0_vk =
h0_kv.transpose(-1, -2).contiguous() and keep the existing map(...
.to(device).requires_grad_(True)) call for (q, k, v, beta, g, h0_kv, h0_vk) to
ensure gradients and device placement remain the same.
In `@tests/ops/test_kda.py`:
- Around line 163-171: Add the same Intel Alchemist skip guard used in
test_fused_recurrent to this transpose test: check IS_INTEL_ALCHEMIST and D >
128 at the start of the test (before creating q,k,v,g,beta,h0_kv/h0_vk) and call
pytest.skip(...) when true so the transpose-layout path is not executed on
Alchemist hardware; reference the symbols IS_INTEL_ALCHEMIST and D and mirror
the skip condition/behavior from test_fused_recurrent.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 264350ca-243a-4ce2-8571-444813ef9859
📒 Files selected for processing (16)
- fla/ops/common/backends/intracard.py
- fla/ops/common/chunk_delta_h.py
- fla/ops/common/chunk_o.py
- fla/ops/common/intracard_cp.py
- fla/ops/cp/chunk_delta_h.py
- fla/ops/gated_delta_rule/chunk.py
- fla/ops/gated_delta_rule/fused_recurrent.py
- fla/ops/gla/chunk.py
- fla/ops/kda/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_fwd.py
- fla/ops/kda/fused_recurrent.py
- tests/context_parallel/test_cp_gdn.py
- tests/context_parallel/test_cp_kda.py
- tests/ops/test_gated_delta.py
- tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/context_parallel/test_cp_kda.py
- fla/ops/kda/chunk.py
```diff
         if USE_GK:
             b_gk = tl.load(p_gk).to(tl.float32)
-            b_h *= exp(b_gk[:, None])
+            if TRANSPOSE_STATE:
+                b_h *= exp(b_gk[None, :])
+            else:
+                b_h *= exp(b_gk[:, None])
         if USE_GV:
             b_gv = tl.load(p_gv).to(tl.float32)
-            b_h *= exp(b_gv[None, :])
         b_v = b_beta * (b_v - tl.sum(b_h * b_k[:, None], 0))
         b_h += b_k[:, None] * b_v
         # [BV]
         b_o = tl.sum(b_h * b_q[:, None], 0)
+        if TRANSPOSE_STATE:
+            b_h *= exp(b_gv[:, None])
+        else:
+            b_h *= exp(b_gv[None, :])
```
🧩 Analysis chain

🏁 Scripts executed against fla-org/flash-linear-attention:

```shell
head -c 1000 "fla/ops/gated_delta_rule/fused_recurrent.py" && echo -e "\n...[file exists]"
wc -l fla/ops/gated_delta_rule/fused_recurrent.py
sed -n '100,140p' fla/ops/gated_delta_rule/fused_recurrent.py
sed -n '170,195p' fla/ops/gated_delta_rule/fused_recurrent.py
# Search for mask_k and mask_v in the kernel to see if they are defined
rg -n "mask_k|mask_v" fla/ops/gated_delta_rule/fused_recurrent.py
# Check the broader kernel signature and how other loads are masked
sed -n '1,100p' fla/ops/gated_delta_rule/fused_recurrent.py | grep -n "def \|@triton\|mask\|load"
# Find where p_gk and p_gv pointers are initialized in the kernel
sed -n '50,90p' fla/ops/gated_delta_rule/fused_recurrent.py
# Check how gk and gv dimensions are used and whether they are padded or not
sed -n '60,85p' fla/ops/gated_delta_rule/fused_recurrent.py
# Look for pointer arithmetic with gk and gv to understand their shape
rg -n "p_gk|p_gv|gk\[|gv\[" fla/ops/gated_delta_rule/fused_recurrent.py | head -20
# Search for how g, gk, gv dimensions are computed and passed to the kernel
sed -n '200,250p' fla/ops/gated_delta_rule/fused_recurrent.py
# Look at the full kernel function signature to see gk, gv parameter types
sed -n '20,75p' fla/ops/gated_delta_rule/fused_recurrent.py
# Check if there are any input shape validations that enforce K/V to be powers of two
sed -n '170,200p' fla/ops/gated_delta_rule/fused_recurrent.py | grep -A 10 "def fused_recurrent_gated_delta_rule_fwd"
# Look for any assertions or constraints on K/V dimensions
rg -n "assert|constraint|power_of_2" fla/ops/gated_delta_rule/fused_recurrent.py
# Check the actual load patterns - confirm they use mask or don't use mask
sed -n '95,130p' fla/ops/gated_delta_rule/fused_recurrent.py
# Verify the pointer arithmetic for all loads to ensure consistency
sed -n '65,75p' fla/ops/gated_delta_rule/fused_recurrent.py
# Look at how b_g is loaded to confirm it's a scalar (no K dimension)
sed -n '110,120p' fla/ops/gated_delta_rule/fused_recurrent.py
# Double-check the full context of pointer initialization for all gate tensors
sed -n '65,85p' fla/ops/gated_delta_rule/fused_recurrent.py
```
Mask gk and gv loads on padded lanes.
When K or V are not powers of two, BK and BV are rounded up, but p_gk and p_gv are loaded without masking. This reads garbage values into the padded lanes, which then corrupt the state computation through the exp(...) multiplications. All other dimension-dependent loads (q, k, v, beta) correctly use masks; gk and gv must too.
🐛 Proposed fix
```diff
-            b_gk = tl.load(p_gk).to(tl.float32)
+            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
-            b_gv = tl.load(p_gv).to(tl.float32)
+            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 if USE_GK:
-    b_gk = tl.load(p_gk).to(tl.float32)
+    b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
     if TRANSPOSE_STATE:
         b_h *= exp(b_gk[None, :])
     else:
         b_h *= exp(b_gk[:, None])
 if USE_GV:
-    b_gv = tl.load(p_gv).to(tl.float32)
+    b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
 b_v = b_beta * (b_v - tl.sum(b_h * b_k[:, None], 0))
 b_h += b_k[:, None] * b_v
 # [BV]
 b_o = tl.sum(b_h * b_q[:, None], 0)
 if TRANSPOSE_STATE:
     b_h *= exp(b_gv[:, None])
 else:
     b_h *= exp(b_gv[None, :])
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 114 - 126, The
loads for p_gk and p_gv are not masked, so padded lanes read garbage and corrupt
b_h via exp(...) multiplications; update the tl.load calls that produce b_gk and
b_gv to use the same padding mask used for q/k/v/beta loads (i.e., pass the mask
and mask_fill value) so out-of-range lanes are zeroed before computing exp and
multiplying into b_h, keeping the existing TRANSPOSE_STATE branching and use of
USE_GK/USE_GV and symbols p_gk, p_gv, b_gk, b_gv, b_h unchanged.
Add transpose_state_layout parameter to chunk, fused_recurrent, and context parallel paths for both KDA and GDN. When enabled, all state tensors use [V,K] layout instead of [K,V] to improve memory access patterns.
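As a pure-PyTorch illustration of what the flag changes (tensor names and sizes below are made up for the sketch): the [V,K] layout is the transpose of the last two dims of the default cache, so a state produced under one layout must be explicitly converted before being reused under the other.

```python
import torch

N, HV, K, V = 2, 4, 64, 128

# Recurrent state cached in the default [K, V] tail layout.
state_kv = torch.randn(N, HV, K, V)

# The layout selected by transpose_state_layout=True stores the same
# values with the last two dims swapped; conversion must be explicit.
state_vk = state_kv.transpose(-1, -2).contiguous()

assert state_kv.shape == (N, HV, K, V)
assert state_vk.shape == (N, HV, V, K)
```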
d4b4909 to 283a560
♻️ Duplicate comments (5)
fla/ops/gated_delta_rule/chunk.py (1)
313-313: ⚠️ Potential issue | 🟠 Major

Validate `initial_state` against `transpose_state_layout`.

This flag changes the expected cache tail from `[H, K, V]` to `[H, V, K]`, but the wrapper still only checks the batch dimension. Passing an old `[N, H, K, V]` state with `transpose_state_layout=True` will be silently misread by the Triton kernels.

🩹 Suggested fix

```diff
 if cu_seqlens is not None:
     if q.shape[0] != 1:
         raise ValueError(
             f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
             f"Please flatten variable-length inputs before processing.",
         )
     if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
         raise ValueError(
             f"The number of initial states is expected to be equal to the number of input sequences, "
             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.",
         )
+if initial_state is not None:
+    expected_tail = (
+        (q.shape[2], v.shape[-1], k.shape[-1])
+        if transpose_state_layout
+        else (q.shape[2], k.shape[-1], v.shape[-1])
+    )
+    if initial_state.ndim != 4 or tuple(initial_state.shape[1:]) != expected_tail:
+        raise ValueError(
+            f"`initial_state` must have shape [N, {expected_tail[0]}, {expected_tail[1]}, {expected_tail[2]}] "
+            f"when transpose_state_layout={transpose_state_layout}, got {tuple(initial_state.shape)}."
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/chunk.py` at line 313, The wrapper must validate that the incoming initial_state shape matches transpose_state_layout: when transpose_state_layout is False expect cache tail layout [N, H, K, V] and when True expect [N, H, V, K]; update the code that handles the initial_state (referencing the transpose_state_layout flag and the initial_state variable/parameter in the gated-delta/chunk wrapper) to check the dimensionality and the order of the last two dims and raise a clear error if they mismatch (include expected vs actual shapes in the message) so an old [N, H, K, V] passed with transpose_state_layout=True is detected rather than silently misread by the Triton kernels.

fla/ops/gated_delta_rule/fused_recurrent.py (2)
171-185: ⚠️ Potential issue | 🟠 Major

Reject mismatched `initial_state` layouts in transpose mode.

This path now allocates `[N, HV, V, K]` final states when `transpose_state_layout=True`, but it still accepts any `initial_state` tail shape. A stale `[N, HV, K, V]` cache will be silently misread by the kernel.

🩹 Suggested fix

```diff
 B, T, H, K, V = *k.shape, v.shape[-1]
 HV = v.shape[2]
 N = B if cu_seqlens is None else len(cu_seqlens) - 1
 BK = triton.next_power_of_2(K)
 BV = min(8, triton.next_power_of_2(V)) if gv is None else triton.next_power_of_2(V)
 NV = triton.cdiv(V, BV)
+
+if initial_state is not None:
+    expected_tail = (HV, V, K) if transpose_state_layout else (HV, K, V)
+    if initial_state.ndim != 4 or tuple(initial_state.shape[1:]) != expected_tail:
+        raise ValueError(
+            f"`initial_state` must have shape [N, {expected_tail[0]}, {expected_tail[1]}, {expected_tail[2]}] "
+            f"when transpose_state_layout={transpose_state_layout}, got {tuple(initial_state.shape)}."
+        )
 o = torch.empty_like(v)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 171 - 185, When transpose_state_layout=True, validate that any provided initial_state has the transposed tail layout [N, HV, V, K] (not [N, HV, K, V]) and raise a clear ValueError if it does not; locate the logic around transpose_state_layout, initial_state, and final_state in fused_recurrent.py (the block that allocates final_state when output_final_state is true) and add a shape/layout check before using or allocating final_state so a stale [N, HV, K, V] cache is rejected rather than silently misread by the kernel.
114-126: ⚠️ Potential issue | 🔴 Critical

Mask `gk`/`gv` loads on padded lanes.

`BK` and `BV` round `K`/`V` up, but `p_gk` and `p_gv` are still loaded without masks. On non-power-of-two heads, the padded lanes feed garbage into the `exp(...)` multipliers and corrupt the state update.

🩹 Suggested fix

```diff
 if USE_GK:
-    b_gk = tl.load(p_gk).to(tl.float32)
+    b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
     if TRANSPOSE_STATE:
         b_h *= exp(b_gk[None, :])
     else:
         b_h *= exp(b_gk[:, None])
 if USE_GV:
-    b_gv = tl.load(p_gv).to(tl.float32)
+    b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
     if TRANSPOSE_STATE:
         b_h *= exp(b_gv[:, None])
     else:
         b_h *= exp(b_gv[None, :])
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/gated_delta_rule/fused_recurrent.py` around lines 114 - 126, The loads for p_gk/p_gv should be masked to avoid reading padded lanes (which BK/BV round up) and injecting garbage into exp multipliers; modify the p_gk and p_gv loads so that tl.load (or equivalent) is passed a mask that zeros out indices beyond the true head size (or load full then set masked entries to 0), and ensure the mask shape matches the TRANSPOSE_STATE branching (i.e., shape the mask as [None, :] vs [:, None] to match how b_gk/b_gv are broadcast into b_h); apply this for both USE_GK (b_gk from p_gk) and USE_GV (b_gv from p_gv) before calling exp and multiplying into b_h.

fla/ops/kda/fused_recurrent.py (1)
239-262: ⚠️ Potential issue | 🟠 Major

Fail fast on recurrent-state layout mismatches.

When `transpose_state_layout=True`, this path switches the kernel and allocated state to `[N, HV, V, K]`, but a provided `initial_state` is still accepted unchecked. Reusing an older `[N, HV, K, V]` cache will be silently reinterpreted.

🩹 Suggested fix

```diff
 B, T, H, K, V = *k.shape, v.shape[-1]
 HV = v.shape[2]
 N = B if cu_seqlens is None else len(cu_seqlens) - 1
 BK = triton.next_power_of_2(K)
 BV = 32
+
+if initial_state is not None:
+    expected_tail = (HV, V, K) if transpose_state_layout else (HV, K, V)
+    if initial_state.ndim != 4 or tuple(initial_state.shape[1:]) != expected_tail:
+        raise ValueError(
+            f"`initial_state` must have shape [N, {expected_tail[0]}, {expected_tail[1]}, {expected_tail[2]}] "
+            f"when transpose_state_layout={transpose_state_layout}, got {tuple(initial_state.shape)}."
+        )
 if out is None:
     out = torch.zeros_like(v)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fla/ops/kda/fused_recurrent.py` around lines 239 - 262, The code allows an initial_state whose memory layout doesn't match transpose_state_layout, causing silent reinterpretation; add a fast-fail shape check before using initial_state: compute N, HV, K, V as in the function and if initial_state is not None assert (or raise ValueError) that initial_state.shape equals (N, HV, V, K) when transpose_state_layout is True and equals (N, HV, K, V) when transpose_state_layout is False, with a clear error message referencing transpose_state_layout, initial_state and expected shape; apply this check around the branch that sets final_state (the logic using transpose_state_layout, inplace_final_state, output_final_state and final_state) so mismatched caches are rejected immediately.

tests/ops/test_kda.py (1)
163-171: ⚠️ Potential issue | 🟡 Minor

Add Intel Alchemist skip guard for consistency.

This test is missing the `IS_INTEL_ALCHEMIST` guard that `test_fused_recurrent` uses. Without it, the test will fail on Alchemist GPUs with D > 128 instead of being skipped.

Suggested fix

```diff
 def test_fused_recurrent_transpose_state(
     B: int,
     T: int,
     H: int,
     D: int,
     scale: float,
     gate_logit_normalizer: float,
     dtype: torch.dtype,
 ):
     torch.manual_seed(42)
+    if IS_INTEL_ALCHEMIST and D > 128:
+        pytest.skip(reason="fused_recurrent_kda is not supported on alchemist for D>128")
     q = torch.rand(B, T, H, D, dtype=dtype)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/ops/test_kda.py` around lines 163 - 171, This test block lacks the IS_INTEL_ALCHEMIST guard that causes tests to skip on Alchemist GPUs when D > 128; add the same check used in test_fused_recurrent: before creating tensors (before torch.manual_seed(42)), check if IS_INTEL_ALCHEMIST and D > 128 and call pytest.skip with a short message, and ensure IS_INTEL_ALCHEMIST (and pytest if not already imported) is available in the test module so the guard can be applied exactly where q, k, v, g, beta, h0_kv, h0_vk are constructed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@fla/ops/gated_delta_rule/chunk.py`:
- Line 313: The wrapper must validate that the incoming initial_state shape
matches transpose_state_layout: when transpose_state_layout is False expect
cache tail layout [N, H, K, V] and when True expect [N, H, V, K]; update the
code that handles the initial_state (referencing the transpose_state_layout flag
and the initial_state variable/parameter in the gated-delta/chunk wrapper) to
check the dimensionality and the order of the last two dims and raise a clear
error if they mismatch (include expected vs actual shapes in the message) so an
old [N, H, K, V] passed with transpose_state_layout=True is detected rather than
silently misread by the Triton kernels.
In `@fla/ops/gated_delta_rule/fused_recurrent.py`:
- Around line 171-185: When transpose_state_layout=True, validate that any
provided initial_state has the transposed tail layout [N, HV, V, K] (not [N, HV,
K, V]) and raise a clear ValueError if it does not; locate the logic around
transpose_state_layout, initial_state, and final_state in fused_recurrent.py
(the block that allocates final_state when output_final_state is true) and add a
shape/layout check before using or allocating final_state so a stale [N, HV, K,
V] cache is rejected rather than silently misread by the kernel.
- Around line 114-126: The loads for p_gk/p_gv should be masked to avoid reading
padded lanes (which BK/BV round up) and injecting garbage into exp multipliers;
modify the p_gk and p_gv loads so that tl.load (or equivalent) is passed a mask
that zeros out indices beyond the true head size (or load full then set masked
entries to 0), and ensure the mask shape matches the TRANSPOSE_STATE branching
(i.e., shape the mask as [None, :] vs [:, None] to match how b_gk/b_gv are
broadcast into b_h); apply this for both USE_GK (b_gk from p_gk) and USE_GV
(b_gv from p_gv) before calling exp and multiplying into b_h.
In `@fla/ops/kda/fused_recurrent.py`:
- Around line 239-262: The code allows an initial_state whose memory layout
doesn't match transpose_state_layout, causing silent reinterpretation; add a
fast-fail shape check before using initial_state: compute N, HV, K, V as in the
function and if initial_state is not None assert (or raise ValueError) that
initial_state.shape equals (N, HV, V, K) when transpose_state_layout is True and
equals (N, HV, K, V) when transpose_state_layout is False, with a clear error
message referencing transpose_state_layout, initial_state and expected shape;
apply this check around the branch that sets final_state (the logic using
transpose_state_layout, inplace_final_state, output_final_state and final_state)
so mismatched caches are rejected immediately.
In `@tests/ops/test_kda.py`:
- Around line 163-171: This test block lacks the IS_INTEL_ALCHEMIST guard that
causes tests to skip on Alchemist GPUs when D > 128; add the same check used in
test_fused_recurrent: before creating tensors (before torch.manual_seed(42)),
check if IS_INTEL_ALCHEMIST and D > 128 and call pytest.skip with a short
message, and ensure IS_INTEL_ALCHEMIST (and pytest if not already imported) is
available in the test module so the guard can be applied exactly where q, k, v,
g, beta, h0_kv, h0_vk are constructed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 39ec030b-a624-4bef-b850-d1e1a77bf52b
📒 Files selected for processing (16)
- fla/ops/common/backends/intracard.py
- fla/ops/common/chunk_delta_h.py
- fla/ops/common/chunk_o.py
- fla/ops/common/intracard_cp.py
- fla/ops/cp/chunk_delta_h.py
- fla/ops/gated_delta_rule/chunk.py
- fla/ops/gated_delta_rule/fused_recurrent.py
- fla/ops/gla/chunk.py
- fla/ops/kda/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_fwd.py
- fla/ops/kda/fused_recurrent.py
- tests/context_parallel/test_cp_gdn.py
- tests/context_parallel/test_cp_kda.py
- tests/ops/test_gated_delta.py
- tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (2)
- fla/ops/kda/chunk_bwd.py
- tests/context_parallel/test_cp_kda.py
Add transpose_state_layout parameter to chunk, fused_recurrent, and context parallel paths for both KDA and GDN. When enabled, all state tensors use [V,K] layout instead of [K,V] to improve memory access patterns.
Summary by CodeRabbit

New Features

- Optional transpose_state_layout flag for KDA and GDN chunk, fused-recurrent, and context-parallel paths; when enabled, state tensors use the [V, K] layout instead of [K, V]. Default behavior is unchanged when the flag is False.

Tests

- New and updated tests covering the transposed state layout for KDA and GDN, including context-parallel variants.