
[KDA] fused bwd kernels inter and prepare wy #688

Merged: yzhangcs merged 16 commits into main from feature/fuse_bwd_kernels on Dec 20, 2025
Conversation

@Nathancgy (Contributor) commented Dec 18, 2025

nsys profile (8 × 4k), mean (ns):

  • fused: 1,151,008.7
  • not fused: 780,636.6 + 652,113.3 = 1,432,749.9
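Spelled out, those means imply the fused kernel saves roughly 280 µs per backward pass; the snippet below only restates the profiled numbers as arithmetic:

```python
# Plain arithmetic on the profiled means quoted above (nanoseconds).
fused = 1_151_008.7                   # fused kernel, mean ns
unfused = 780_636.6 + 652_113.3       # two separate kernels, mean ns
saved = unfused - fused               # ns saved per backward pass
speedup = unfused / fused
print(f"unfused total: {unfused:.1f} ns")
print(f"saved: {saved:.1f} ns (~{saved / 1e3:.0f} us)")
print(f"speedup: {speedup:.2f}x")
```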

This PR fuses chunk_kda_bwd_kernel_inter and prepare_wy_repr_bwd_kda_kernel into a single kernel to reduce memory bandwidth usage in the backward pass.

Previously, the backward pass computed inter-chunk gradients in one kernel (writing dw, dk, dg to global memory), then immediately read them back in the WY backward kernel. The fused kernel eliminates this redundant memory traffic by keeping dw in registers and computing the WY backward contributions in the same kernel launch.

Key changes:

  • Added chunk_kda_bwd_kernel_inter_wy_fused that combines both kernels with an interleaved K-block loop design
  • dw never hits global memory (stays in registers)
  • dk and dg are written once instead of write-read-write
  • Eliminates the overhead of one kernel launch
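As a toy illustration of the dataflow change (made-up math, not the real kernels): unfused, the intermediate dw round-trips through a buffer standing in for global memory; fused, it stays a local variable and the result is identical.

```python
import numpy as np

def unfused(k, do, global_mem):
    # Kernel 1: compute dw and write it to "global memory".
    global_mem["dw"] = k * do
    # Kernel 2: read dw back and consume it.
    dw = global_mem["dw"]
    return dw @ dw.T

def fused(k, do):
    # Single kernel: dw never leaves "registers" (a local variable).
    dw = k * do
    return dw @ dw.T

rng = np.random.default_rng(0)
k, do = rng.standard_normal((4, 3)), rng.standard_normal((4, 3))
mem = {}
out_a, out_b = unfused(k, do, mem), fused(k, do)
assert np.allclose(out_a, out_b)  # same result, no dw round-trip in the fused path
```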

Summary by CodeRabbit

  • Refactor
    • Reworked and fused the attention backward path for more efficient block-level computation and dataflow.
    • Public call signature remains effectively unchanged, but observable outputs now include additional gradient components (dv, db, dA).
    • Removed an intermediate representation step and consolidated gradient propagation across the fused path for streamlined backward execution.


coderabbitai bot (Contributor) commented Dec 18, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Replaces the previous KDA backward path with a fused WY+dqkg kernel/wrapper that accepts v_new, beta, and A, returns db and dA (and propagates dv), removes the prepare_wy_repr_bwd step, and updates intra-chunk backward wiring to consume dAkk.

Changes

Fused backward kernel & wrapper — fla/ops/kda/chunk_bwd.py
Renamed and rewrote the fused kernel & wrapper: chunk_kda_bwd_kernel_dqkwg → chunk_kda_bwd_kernel_wy_dqkg_fused, wrapper chunk_kda_bwd_wy_dqkg_fused. Signatures extended to accept v_new, beta, and A; new outputs db and dA added; internal path uses dv2 and returns (dq, dk, dv, db, dg, dA). Grid/indexing and per-block pointer wiring updated to compute/accumulate dA and db.

Backward call site & imports — fla/ops/kda/chunk.py
Updated imports to use chunk_kda_bwd_dAv and chunk_kda_bwd_wy_dqkg_fused; removed the prepare_wy_repr_bwd import. chunk_kda_bwd now calls the fused wrapper with v, v_new, beta, A=Akk, propagates dv, and receives dAkk. chunk_kda_bwd_intra now consumes dAqk=dq and dAkk; observable outputs include db and dAkk while keeping the overall return shape (dq, dk, dv, db, dg, dh0).

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

  • Review focus:
    • Kernel grid/program_id/NT and indexing changes in fla/ops/kda/chunk_bwd.py.
    • Correct accumulation, memory loads/stores and pointer wiring for dA and db.
    • Call-site argument passing (v_new, beta, A=Akk) and downstream handling of dAkk in fla/ops/kda/chunk.py.

Possibly related PRs

  • [KDA] Fuse dAqk and dv #689 — Directly related earlier work adding chunk_kda_bwd_dAv and modifying chunk backward kernels/wrappers that this change continues to evolve.
  • Add KDA #621 — Prior changes to the dqkwg kernel/wrapper and WY handling that align with the new fused API.
  • [KDA] Fuse inplace add #682 — Related changes to WY backward wiring and intermediate gradient buffer handling tied to removal of prepare_wy_repr_bwd.

Poem

🐰 I hopped through kernels, swift and bright,
I fused WY paths by moonlit byte,
dA and db now tumble free,
Pointers tidy, loops agree,
A carrot patch of code—delight! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed — Check skipped: CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title references 'fused bwd kernels inter and prepare wy', which aligns with the PR's objective of fusing chunk_kda_bwd_kernel_inter and prepare_wy_repr_bwd_kda_kernel into a single kernel (chunk_kda_bwd_kernel_inter_wy_fused). The title captures the main technical change.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @Nathancgy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly optimizes the backward pass for KDA operations by fusing two previously separate kernels into a single, more efficient kernel. This fusion strategy directly addresses memory bandwidth bottlenecks and reduces kernel launch overhead, leading to improved performance. The core idea is to keep intermediate gradients in registers, avoiding costly global memory access, and streamline the computation of other gradients by eliminating redundant write-read cycles.

Highlights

  • Kernel Fusion: Fusing chunk_kda_bwd_kernel_inter and prepare_wy_repr_bwd_kda_kernel into a single kernel, chunk_kda_bwd_kernel_inter_wy_fused, to optimize the backward pass.
  • Memory Bandwidth Reduction: Eliminated redundant memory traffic by keeping dw in registers, preventing it from being written to and read from global memory.
  • Reduced Kernel Launches: Consolidated two separate kernel launches into one, reducing overhead and improving performance.
  • Optimized Gradient Writes: dk and dg are now written to global memory only once, instead of a write-read-write cycle.
  • Performance Improvement: The fused kernel shows a mean execution time of 1,151,008.7 ns, compared to 780,636.6 + 652,113.3 = 1,432,749.9 ns for the unfused kernels — roughly a 1.24× speedup.



gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant performance optimization by fusing two backward pass kernels, chunk_kda_bwd_kernel_inter and prepare_wy_repr_bwd_kda_kernel, into a single chunk_kda_bwd_kernel_inter_wy_fused kernel. This is a great approach to reduce memory bandwidth by keeping intermediate gradients like dw in registers. The changes in chunk.py and the removal of obsolete code in wy_fast.py are well-executed. However, my review of the new fused Triton kernel in chunk_inter.py has identified several critical issues related to incorrect memory addressing and gradient calculations that will lead to incorrect results. These issues need to be addressed to ensure the correctness of the backward pass.

Comment thread fla/ops/kda/chunk_bwd.py
p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))

p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))

critical

The tl.make_block_ptr for matrix A is configured incorrectly. The shape and strides seem to be for a transposed view, and the start offset calculation is incorrect. This will lead to out-of-bounds memory access and incorrect results. The shape of A for a given head is (T, BT), so the pointer should be configured accordingly.

Suggested change
p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
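The layout claim is easy to check outside Triton. A minimal NumPy sketch (illustrative shapes; `as_strided` plays the role of `make_block_ptr`) shows that a per-head view of A with shape (T, BT) and element strides (H*BT, 1) picks out the intended (BT, BT) tile at row offset i_t * BT:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

T, H, BT = 8, 2, 4
buf = np.arange(T * H * BT, dtype=np.float32)   # flat backing store, laid out [T, H, BT]
h, i_t = 0, 1

# Per-head view of A: shape (T, BT), strides (H*BT, 1) in elements,
# mirroring the corrected make_block_ptr configuration.
base = buf[h * BT:]
A_head = as_strided(base, shape=(T, BT),
                    strides=(H * BT * base.itemsize, base.itemsize))

tile = A_head[i_t * BT:(i_t + 1) * BT, :]        # the (BT, BT) block at row offset i_t*BT
ref = buf.reshape(T, H, BT)[i_t * BT:(i_t + 1) * BT, h, :]
assert np.array_equal(tile, ref)                 # matches the [T, H, BT] layout exactly
```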

Comment thread fla/ops/kda/chunk_bwd.py Outdated
b_dw_neg_cast = b_dw_neg.to(b_A.dtype)
b_dA += tl.dot(b_dw_neg_cast, tl.trans(b_kbg))

b_dkbg = tl.dot(b_A, b_dw_neg_cast)

critical

The gradient calculation for b_dkbg is incorrect. The forward pass involves a multiplication with A (which represents Akk_inv), so the backward pass requires multiplication with A.T. The code is missing the transpose on b_A.

Suggested change
b_dkbg = tl.dot(b_A, b_dw_neg_cast)
b_dkbg = tl.dot(tl.trans(b_A), b_dw_neg_cast)
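The transpose requirement here is the standard matmul backward rule: if the forward pass computes y = A @ x, then dx = Aᵀ @ dy. A quick NumPy finite-difference check (generic shapes, unrelated to the kernel's actual tiles):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
x = rng.standard_normal((4, 3))
dy = rng.standard_normal((4, 3))        # upstream gradient of y = A @ x

dx = A.T @ dy                           # backward rule: transpose of A

# Finite-difference check of d(sum(dy * (A @ x)))/dx.
eps = 1e-6
dx_fd = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp, xm = x.copy(), x.copy()
        xp[i, j] += eps
        xm[i, j] -= eps
        dx_fd[i, j] = (np.sum(dy * (A @ xp)) - np.sum(dy * (A @ xm))) / (2 * eps)

assert np.allclose(dx, dx_fd, atol=1e-5)          # A.T @ dy matches
assert not np.allclose(A @ dy, dx_fd, atol=1e-3)  # A @ dy (no transpose) does not
```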

Comment thread fla/ops/kda/chunk_bwd.py Outdated

b_dA += tl.dot(b_du, tl.trans(b_vb))

b_dvb = tl.dot(b_A, b_du)

critical

Similar to the calculation of b_dkbg, the gradient b_dvb is calculated without transposing b_A. This is incorrect as the backpropagation requires multiplication with the transposed matrix.

Suggested change
b_dvb = tl.dot(b_A, b_du)
b_dvb = tl.dot(tl.trans(b_A), b_du)

Comment thread fla/ops/kda/chunk_bwd.py
Comment on lines +190 to +191
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))

critical

The gradient calculation for dAkk from dAkk_inv (represented by b_dA) is incorrect. The derivative of a matrix inverse A = X^{-1} is dX = -A^T dA A^T. The current implementation computes A @ (dA @ A), which is mathematically incorrect. The multiplication should be with the transpose of b_A.

Suggested change
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
b_dA = tl.dot(b_dA_t, tl.trans(b_A))
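The inverse-gradient identity is also easy to verify numerically: for A = X⁻¹ and loss L = Σ(dA ⊙ A), the gradient is dX = −Aᵀ·dA·Aᵀ. A NumPy finite-difference check on a random well-conditioned X (ignoring the triangular masking the kernel also applies):

```python
import numpy as np

rng = np.random.default_rng(1)
X = np.eye(4) + 0.1 * rng.standard_normal((4, 4))   # well-conditioned
A = np.linalg.inv(X)
dA = rng.standard_normal((4, 4))                     # upstream gradient w.r.t. A = X^{-1}

dX = -A.T @ dA @ A.T                                 # the suggested formula

# Finite-difference check on L = sum(dA * inv(X)).
eps = 1e-6
dX_fd = np.zeros_like(X)
for i in range(4):
    for j in range(4):
        Xp, Xm = X.copy(), X.copy()
        Xp[i, j] += eps
        Xm[i, j] -= eps
        dX_fd[i, j] = (np.sum(dA * np.linalg.inv(Xp))
                       - np.sum(dA * np.linalg.inv(Xm))) / (2 * eps)

assert np.allclose(dX, dX_fd, atol=1e-4)
```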

coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
fla/ops/kda/chunk_inter.py (1)

231-232: Unused meta parameter in grid function.

The meta parameter is required by Triton's grid API but unused here. This is a known pattern and the static analysis warning can be safely suppressed with a leading underscore.

🔎 Suggested fix:
-    def grid(meta):
+    def grid(_meta):
         return (NT, B * H)
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9714c59 and fb5892b.

📒 Files selected for processing (3)
  • fla/ops/kda/chunk.py (2 hunks)
  • fla/ops/kda/chunk_inter.py (5 hunks)
  • fla/ops/kda/wy_fast.py (0 hunks)
💤 Files with no reviewable changes (1)
  • fla/ops/kda/wy_fast.py
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/kda/chunk.py (2)
fla/ops/kda/chunk_inter.py (1)
  • chunk_kda_bwd_dqkwg_wy_fused (200-261)
fla/ops/kda/wy_fast.py (1)
  • recompute_w_u_fwd (103-147)
🪛 GitHub Actions: lint
fla/ops/kda/chunk_inter.py

[error] 1-1: End-of-file-fixer: Fixed missing newline at end of file.


[error] 1-1: Ruff: 1 issue fixed during pre-commit run (formatting/linting).

🪛 Ruff (0.14.8)
fla/ops/kda/chunk_inter.py

231-231: Unused function argument: meta

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (6)
fla/ops/kda/chunk.py (2)

9-12: LGTM! Import changes align with kernel fusion.

The imports correctly reflect the transition from separate kernels to the fused chunk_kda_bwd_dqkwg_wy_fused, with prepare_wy_repr_bwd appropriately removed since its logic is now integrated into the fused kernel.


140-158: LGTM! Fused kernel integration is well-structured.

The transition to chunk_kda_bwd_dqkwg_wy_fused is correctly implemented:

  • v_org (original v) is passed alongside transformed v_new
  • A=Akk properly supplies the attention matrix
  • Additional output dAkk is correctly propagated to chunk_kda_bwd_intra
  • The explanatory comments clearly document the memory bandwidth optimization rationale
fla/ops/kda/chunk_inter.py (4)

31-60: LGTM! Well-designed fused kernel signature.

The kernel signature properly extends the original with:

  • v_org, beta, A inputs for WY backward computation
  • dv_in for input gradients, db, dA for new outputs
  • Heuristics and autotune configurations appropriately retained

61-95: LGTM! Pointer setup and indexing are correct.

The simplified program ID handling (i_t, i_bh) and pointer arithmetic for all tensors including the new v_org, beta, A, dv_in, db, dA correctly match the expected memory layouts. Variable-length sequence support via IS_VARLEN is properly preserved.


106-170: LGTM! Core computation loop correctly fuses inter-chunk and WY backward logic.

Key observations:

  • b_dw is kept in registers and never written to global memory (per PR objectives)
  • WY backward contributions computed inline (b_dk_wy, b_dg_wy, b_db)
  • b_dk = b_dk_inter + b_dk_wy properly combines both gradient sources
  • b_dA accumulation via tl.dot(b_dw_neg_cast, tl.trans(b_kbg)) is correct

The interleaved K-block loop design eliminates the redundant global memory traffic.


172-197: LGTM! V-dimension loop and dA finalization are correctly implemented.

The V-loop properly accumulates additional b_dA and b_db contributions from v_org. The dA finalization sequence (masking → two tl.dot with transposed b_A → final masking with negation) correctly computes the gradient through the lower-triangular structure.

Comment thread fla/ops/kda/chunk_inter.py Outdated
coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
fla/ops/kda/chunk_bwd.py (1)

109-276: Fusion design is sound, but critical correctness issues must be fixed first.

The kernel fusion approach successfully integrates WY backward preparation into the main backward kernel, keeping intermediate gradients in registers. However, the four critical mathematical errors (lines 178, 231, 261, 268-269) will prevent correct gradient computation and must be addressed before this PR can be merged.

Once these issues are resolved, the fused kernel should deliver the performance benefits described in the PR objectives.

♻️ Duplicate comments (4)
fla/ops/kda/chunk_bwd.py (4)

178-179: CRITICAL: Block pointer for matrix A configured incorrectly.

The tl.make_block_ptr configuration has swapped dimensions, strides, and offsets. After the pointer offset at line 163, matrix A has shape [T, BT] with strides (H*BT, 1). The current configuration will cause out-of-bounds memory access and incorrect gradient computation.

🔎 Proposed fix
-    p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
+    p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

231-231: CRITICAL: Missing transpose in gradient calculation.

The gradient b_dkbg is computed without transposing b_A. Since the forward pass involves A @ x, the backward pass requires A^T @ dx. This error will produce incorrect gradients for the key and beta parameters.

🔎 Proposed fix
-        b_dkbg = tl.dot(b_A, b_dw_neg_cast)
+        b_dkbg = tl.dot(tl.trans(b_A), b_dw_neg_cast)

261-261: CRITICAL: Missing transpose in gradient calculation for dv.

Similar to line 231, the gradient b_dvb is computed without transposing b_A. The backpropagation through A @ u requires A^T @ du, so the transpose is mathematically necessary.

🔎 Proposed fix
-        b_dvb = tl.dot(b_A, b_du)
+        b_dvb = tl.dot(tl.trans(b_A), b_du)

268-269: CRITICAL: Incorrect gradient formula for matrix inverse.

The gradient of a matrix inverse A = X^{-1} is dX = -A^T dA A^T. The current implementation computes A @ (dA @ A) without transposes, which is mathematically incorrect and will produce wrong gradients for Akk.

🔎 Proposed fix
-    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
-    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+    b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+    b_dA = tl.dot(b_dA_t, tl.trans(b_A))
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bb03318 and ecff380.

📒 Files selected for processing (2)
  • fla/ops/kda/chunk.py (2 hunks)
  • fla/ops/kda/chunk_bwd.py (6 hunks)
🧰 Additional context used
🪛 Ruff (0.14.8)
fla/ops/kda/chunk_bwd.py

360-360: Unused function argument: meta

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (6)
fla/ops/kda/chunk.py (4)

8-8: LGTM: Import changes align with kernel fusion.

The imports correctly reflect the fused backward kernel and the recomputation approach. The removal of prepare_wy_repr_bwd is consistent with integrating WY backward logic into the main kernel.

Also applies to: 11-11


84-93: LGTM: Forward recomputation for backward pass.

The recomputation of w, u, qg, kg using recompute_w_u_fwd is a standard memory-saving technique. The explicit A=Akk parameter binding is clear and correct.


133-149: LGTM: Fused kernel call with extended parameters.

The fused kernel correctly receives:

  • v_org=v (original values) and v=v_new (transformed values) for WY backward computation
  • A=Akk for the inverse matrix operations
  • dv=dv as both input and output (accumulated gradients)
  • Additional outputs dAkk for downstream intra-chunk backward

This aligns with the fusion objectives to eliminate redundant memory traffic.


150-164: LGTM: Intra-chunk backward updated to use dAkk.

The chunk_kda_bwd_intra call correctly consumes dAkk from the fused kernel output alongside dAqk, maintaining proper gradient flow through the backward pass.

fla/ops/kda/chunk_bwd.py (2)

109-138: LGTM: Fused kernel signature with extended parameters.

The kernel signature correctly includes all parameters needed for the fused WY backward computation: v_org (original values), beta, A (Akk inverse), dv_in (input gradient), and additional outputs db and dA.


329-390: LGTM: Wrapper function correctly allocates and routes tensors.

The wrapper properly:

  • Allocates new output tensors (dv_out, db, dA)
  • Distinguishes between input dv (passed as dv_in) and output dv_out
  • Returns the extended gradient set (dq, dk, dv_out, db, dg, dA)

Note: The meta parameter at line 360 is required by Triton's grid function interface, so the static analysis warning is a false positive.

coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
fla/ops/kda/chunk_bwd.py (1)

353-358: Consider pre-allocating output tensors with correct dtype.

Output tensors dq, dk, and dg are allocated with dtype=torch.float, while dv2 and db use the input tensor's dtype. For consistency and potential performance benefits, consider explicitly specifying dtype=torch.float for all gradient outputs, or document why dv2 should match the input dtype.

Additionally, line 388 reassigns dv = dv2, which means the final dv output will have the same dtype as the input v. Ensure this matches the expected API behavior.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d2f6591 and 5d6d0cc.

📒 Files selected for processing (1)
  • fla/ops/kda/chunk_bwd.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/kda/chunk_bwd.py (1)
fla/utils.py (1)
  • check_shared_mem (447-453)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (3)
fla/ops/kda/chunk_bwd.py (3)

247-247: Verify gradient calculation requires transpose of A.

Similar to line 230, the gradient b_dkbg = tl.dot(b_A, b_dw) may require a transpose based on the forward pass formulation. The past review comment suggested: b_dkbg = tl.dot(tl.trans(b_A), b_dw_neg_cast).

Please verify with the forward pass whether matrix A should be transposed here.


360-388: Verify input dv is used correctly in the fused kernel.

The function accepts dv as an input parameter (line 340) and passes it to the kernel (line 374), but then overwrites it with dv2 (line 388) before returning.

Looking at the kernel signature (lines 122-123), both dv and dv2 are separate parameters:

  • dv appears to be read in the kernel (line 214: b_dv = tl.load(p_dv, ...))
  • dv2 is written as output (line 234: tl.store(p_dv2, ...))

Please verify:

  1. Is dv an input gradient that should be accumulated, or is it just a workspace?
  2. Should the function signature clarify that dv is an input and the returned dv is actually from dv2?

The current code suggests dv is an upstream gradient that gets accumulated in the kernel, then dv2 contains the final output. This flow should be documented.
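To make the questioned flow concrete, here is a minimal Python sketch of the pattern the review describes (hypothetical names; `fused_bwd` and the `+ 1.0` accumulation are stand-ins, not the real wrapper): the upstream gradient comes in as dv, the kernel writes a separate dv2 buffer, and the wrapper returns dv2 under the name dv.

```python
import numpy as np

def fused_bwd(dv):
    # dv: upstream gradient, read-only input (the kernel's `dv` parameter).
    dv2 = np.empty_like(dv)      # separate output buffer (the kernel's `dv2`)
    dv2[:] = dv + 1.0            # stand-in for the kernel's accumulation
    return dv2                   # caller receives this as the "dv" output

dv_in = np.zeros((2, 2))
dv_out = fused_bwd(dv_in)
assert dv_out is not dv_in       # output buffer is distinct from the input
assert np.all(dv_in == 0.0)      # the input gradient is never mutated
```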


252-257: Gate gradient formula is correct.

The formula at line 256 properly combines all gradient paths:

  1. Negative sign on b_k * b_dk is correct. The delta rule state update is S = S*exp(g) + β*k*(v - k^T*S). The key gradient involves both the forward contribution (through the state) and a correction term from the rank-1 update. The negative sign correctly represents this correction through the gate-scaled state.

  2. b_dgk masking with m_last is correct. The accumulated b_dgk represents gradients from both value and key paths that affect the state at the chunk boundary. It should only contribute to the gate gradient at the final position of each chunk, since that gate value governs the decay to subsequent chunks. Intermediate gates affect within-chunk computation through the cumsum, not at boundaries.

  3. The formula structure is validated. The test suite confirms gradient correctness with appropriate tolerances (0.02 for dg), and the pattern matches related implementations in chunk_intra.py. The exponential gate scaling via b_gn and the WY contribution via b_dg_wy are both properly incorporated.

Comment thread fla/ops/kda/chunk_bwd.py
Comment on lines +178 to +179
p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
b_A = tl.load(p_A, boundary_check=(0, 1))

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for A allocation and block pointer usage in chunk_kda_bwd.py and related files
echo "=== Looking for A allocation in chunk_kda_bwd.py and related forward kernels ==="

# Find the wrapper function for chunk_kda_bwd_wy_dqkg_fused
rg -n "def chunk_kda_bwd_wy_dqkg_fused" fla/ops/kda/ -A 50 | head -100

echo ""
echo "=== Looking for A creation in forward or other files ==="

# Search for A.new_empty or similar allocations related to KDA
fd . fla/ops/kda/ --type f -name "*.py" | head -20

echo ""
echo "=== Check what A is in the backward kernel context ==="
rg -n "p_A.*make_block_ptr" fla/ops/kda/ -B 5 -A 2

Repository: fla-org/flash-linear-attention

Length of output: 10739


Correct block pointer configuration to match forward kernel memory layout.

The forward kernel (chunk_intra.py) allocates and accesses A with shape (T, BT) and strides (H*BT, 1). The backward kernel must use the same configuration to access the same memory layout correctly.

Change line 178 from:

p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))

To:

p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

The transposed configuration (BT, T) with strides (1, H*BT) is incompatible with how A is allocated in the forward pass and will cause incorrect memory access patterns.

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around lines 178 to 179, the block pointer for A is
configured with shape (BT, T) and strides (1, H * BT) which is transposed
relative to the forward kernel; replace that configuration so it matches the
forward kernel's allocation/access (shape (T, BT) and strides (H * BT, 1)),
i.e., change the make_block_ptr call to use (T, BT), (H * BT, 1), and adjust the
offset/transpose parameters accordingly so the pointer addresses the same memory
layout as the forward pass.

Comment thread fla/ops/kda/chunk_bwd.py
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
b_dA += tl.dot(b_dv, tl.trans(b_vb))

b_dvb = tl.dot(b_A, b_dv)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Find the forward kernel file
find . -name "*chunk_fwd*" -o -name "*kda*fwd*" | head -20

Repository: fla-org/flash-linear-attention

Length of output: 56


🏁 Script executed:

# List files in the kda directory to find forward pass
ls -la fla/ops/kda/

Repository: fla-org/flash-linear-attention

Length of output: 775


🏁 Script executed:

# Look at the forward chunk.py file to understand the forward pass
cat -n fla/ops/kda/chunk.py | head -100

Repository: fla-org/flash-linear-attention

Length of output: 3402


🏁 Script executed:

# Search for where b_dvb is computed and look at surrounding context
grep -n "b_dvb\|b_dkbg\|b_dA" fla/ops/kda/chunk_bwd.py | head -20

Repository: fla-org/flash-linear-attention

Length of output: 1046


🏁 Script executed:

# Look at the context around line 230 to understand the forward pass computation
sed -n '200,275p' fla/ops/kda/chunk_bwd.py

Repository: fla-org/flash-linear-attention

Length of output: 3681


🏁 Script executed:

# Look at the wy_fast.py file to see how the WY representation forward pass is computed
cat -n fla/ops/kda/wy_fast.py | grep -A 20 -B 5 "v_new\|vb"

Repository: fla-org/flash-linear-attention

Length of output: 3501


🏁 Script executed:

# Now let's verify the matrix inverse gradient computation at lines 268-269
sed -n '260,275p' fla/ops/kda/chunk_bwd.py

Repository: fla-org/flash-linear-attention

Length of output: 1022


🏁 Script executed:

# Check if there are any comments or documentation about the WY representation
grep -n "WY\|inverse\|transpose" fla/ops/kda/wy_fast.py | head -20

Repository: fla-org/flash-linear-attention

Length of output: 56


🏁 Script executed:

# Let me look at the complete recompute_w_u_fwd function to understand the forward pass better
sed -n '212,250p' fla/ops/kda/wy_fast.py

Repository: fla-org/flash-linear-attention

Length of output: 1134


🏁 Script executed:

# Let's examine the mathematical operation more carefully - check what the actual forward computation is
sed -n '40,75p' fla/ops/kda/wy_fast.py

Repository: fla-org/flash-linear-attention

Length of output: 1567


Missing transpose in backward gradient computation through WY representation.

Line 230 computes b_dvb = tl.dot(b_A, b_dv), but the forward pass applies b_u = tl.dot(b_A, b_vb) where b_vb = b_v * b_beta. For the gradient to flow correctly through a matrix multiplication, the backward pass must apply the transpose: b_dvb = tl.dot(tl.trans(b_A), b_dv).

The same issue affects line 247: b_dkbg = tl.dot(b_A, b_dw) should be b_dkbg = tl.dot(tl.trans(b_A), b_dw).

Additionally, lines 268-269 implement the gradient of matrix inverse incorrectly. The derivative of f(X^{-1}) requires dX = -X^{-T} @ dY @ X^{-T}, but the current code only computes A @ dA @ A without proper transposes.

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around lines 230, 247 and 268-269, the backward pass
applies matrix multiplications and inverse gradients without required
transposes: change line 230 b_dvb = tl.dot(b_A, b_dv) to b_dvb =
tl.dot(tl.trans(b_A), b_dv); change line 247 b_dkbg = tl.dot(b_A, b_dw) to
b_dkbg = tl.dot(tl.trans(b_A), b_dw); and replace the incorrect inverse-gradient
computation at lines 268-269 (which currently uses A @ dA @ A) with the correct
form using transposed inverses: dX = -A_T @ dY @ A_T where A_T is tl.trans(A)
(i.e., use A_T = tl.trans(A) and compute -tl.dot(tl.dot(A_T, dY), A_T)). Ensure
variable names match existing locals and use tl.trans consistently.

Comment thread fla/ops/kda/chunk_bwd.py
Comment on lines +266 to +270
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_dA = tl.where(m_A, b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA = tl.where(m_A, -b_dA, 0)

⚠️ Potential issue | 🔴 Critical

Critical: Incorrect gradient calculation for matrix inverse.

The gradient of a matrix inverse is computed incorrectly. For A = X^{-1}, the derivative is:

dX = -A^T @ dA @ A^T

However, the current implementation computes:

b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)      # dA @ A
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))      # A @ (dA @ A)
b_dA = tl.where(m_A, -b_dA, 0)               # negation

This computes -(A @ dA @ A) instead of -(A^T @ dA @ A^T). The transposes are missing.
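The inverse-gradient formula can likewise be verified numerically. A NumPy sketch under toy sizes: for Y = inv(X) and loss L = sum(dY * Y), finite differences recover dX = -Y^T @ dY @ Y^T:

```python
import numpy as np

# Verify: for Y = inv(X) with upstream gradient dY, dX = -Y^T @ dY @ Y^T.
rng = np.random.default_rng(1)
X = rng.standard_normal((4, 4)) + 4 * np.eye(4)  # well-conditioned toy matrix
dY = rng.standard_normal((4, 4))                 # upstream gradient dL/dY
Y = np.linalg.inv(X)

dX_analytic = -Y.T @ dY @ Y.T

# Central finite differences of L(X) = sum(dY * inv(X))
eps = 1e-6
dX_numeric = np.zeros_like(X)
for i in range(4):
    for j in range(4):
        Xp, Xm = X.copy(), X.copy()
        Xp[i, j] += eps
        Xm[i, j] -= eps
        dX_numeric[i, j] = (np.sum(dY * np.linalg.inv(Xp))
                            - np.sum(dY * np.linalg.inv(Xm))) / (2 * eps)

assert np.allclose(dX_analytic, dX_numeric, atol=1e-4)
```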

🔎 Proposed fix
-    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
-    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+    b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+    b_dA = tl.dot(b_dA_t, tl.trans(b_A))
     b_dA = tl.where(m_A, -b_dA, 0)
🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around lines 266 to 270, the gradient for the matrix
inverse is computed as -(A @ dA @ A) but must be -(A^T @ dA @ A^T); replace the
three dot calls with operations that compute b_dA = - (b_A.T @ b_dA @ b_A.T)
(respecting dtype casts as needed), then apply the existing mask m_A (i.e., b_dA
= tl.where(m_A, b_dA, 0)). Ensure you cast operands to matching dtypes before
each dot and perform the transposes on b_A (and on the intermediate if required)
so the final result matches -(A^T @ dA @ A^T).

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
fla/ops/kda/chunk_bwd.py (4)

178-179: CRITICAL: Block pointer configuration for A still incorrect (previously flagged).

This issue was flagged in previous reviews but remains unfixed. The forward kernel allocates A with shape (T, BT) and strides (H*BT, 1). The backward kernel must use the same memory layout configuration.

🔎 Required fix
-    p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
+    p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

Based on past review comments.
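Whatever the intended layout, swapping shape and strides in this way reads the buffer transposed. A NumPy `as_strided` sketch (toy sizes, with H = 1 so the strides simplify) makes the effect concrete:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

T, BT = 6, 4  # toy sizes; H = 1 so strides (H*BT, 1) become (BT, 1)
A = np.arange(T * BT, dtype=np.float32).reshape(T, BT)  # allocated (T, BT), row-major

sz = A.itemsize
# View with shape (BT, T) and strides (1, BT), analogous to the flagged
# block pointer: element [i, j] reads flat index j * BT + i, i.e. A[j, i].
A_view = as_strided(A, shape=(BT, T), strides=(1 * sz, BT * sz))

assert np.array_equal(A_view, A.T)  # the swapped configuration yields A transposed
```

So the two configurations are not interchangeable: one loads A, the other A^T. Which one the fused kernel actually needs depends on whether the subsequent `tl.dot` calls expect A or its transpose.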


229-235: CRITICAL: Missing transpose in WY backward gradient computation (previously flagged).

Line 231 computes b_dvb = tl.dot(b_A, b_dv), but the forward pass applies u = A @ vb. For correct backpropagation through matrix multiplication, the gradient must use the transpose: dvb = A^T @ du.

🔎 Required fix
-        b_dvb = tl.dot(b_A, b_dv)
+        b_dvb = tl.dot(tl.trans(b_A), b_dv)

Based on past review comments.


244-249: CRITICAL: Missing transpose in K-branch gradient computation (previously flagged).

Line 247 computes b_dkbg = tl.dot(b_A, b_dw) without transposing b_A. For correct backpropagation through the matrix multiplication A @ x, the gradient must be A^T @ dx.

🔎 Required fix
-        b_dkbg = tl.dot(b_A, b_dw)
+        b_dkbg = tl.dot(tl.trans(b_A), b_dw)

Based on past review comments.


265-269: CRITICAL: Incorrect matrix inverse gradient formula (previously flagged).

The derivative of a matrix inverse Y = X^{-1} is dX = -Y^T @ dY @ Y^T. The current implementation computes -(A @ dA @ A) without the required transposes.

🔎 Required fix
-    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
-    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+    b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+    b_dA = tl.dot(b_dA_t, tl.trans(b_A))
     b_dA = tl.where(m_A, -b_dA, 0)

Based on past review comments.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 68de662 and 845d665.

📒 Files selected for processing (1)
  • fla/ops/kda/chunk_bwd.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (4)
fla/ops/kda/chunk_bwd.py (4)

109-137: LGTM: Kernel signature updated to support fused WY backward.

The addition of v_new, beta, A as inputs and dv2, db, dA as outputs correctly extends the kernel API to support the fused WY representation backward pass within the same kernel launch.


139-152: LGTM: Grid simplified to 2D for fused kernel.

The removal of the i_k program axis is correct—the K-dimension loop is now handled within the kernel body (lines 184-263), eliminating one grid dimension and enabling the fusion.


328-388: LGTM: Wrapper correctly orchestrates fused kernel launch.

The wrapper properly allocates output tensors (dv2, db, dA), launches the fused kernel with the extended parameter set, and returns the expected gradients. The reassignment dv = dv2 at line 387 correctly uses the kernel's output before returning.


184-263: Approve fusion design, but correctness depends on fixing transpose issues.

The main loop successfully fuses the K-block and V-block iterations, computing WY backward contributions conditionally (i_k == 0) and keeping intermediate gradient dw in registers to avoid redundant global memory traffic. This achieves the PR's stated goal of reducing memory bandwidth and kernel launch overhead.

However, the kernel will produce incorrect results until the transpose issues at lines 231, 247, and 267-268 are fixed.

Comment thread fla/ops/kda/chunk_bwd.py
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dgk += tl.sum(b_dk * b_k, axis=0)
b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk
tl.debug_barrier()

⚠️ Potential issue | 🟡 Minor

Consider removing or conditionalizing the debug barrier.

The tl.debug_barrier() is useful for debugging synchronization issues but may impact performance in production. Consider removing it once the kernel is validated, or guard it with a compile-time flag.
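One common pattern is to resolve a debug flag once at import time and branch on it, so the barrier disappears from the hot path by default; in Triton the flag would be threaded through as a `tl.constexpr` so the branch is compiled out entirely. A hypothetical Python sketch (`FLA_DEBUG` is an invented variable name, and `barrier_if_debug` stands in for `tl.debug_barrier()`):

```python
import os

# Hypothetical flag, read once at module load.
DEBUG = os.environ.get("FLA_DEBUG", "0") == "1"

def barrier_if_debug(trace):
    # Stand-in for tl.debug_barrier(); records when it fires.
    if DEBUG:
        trace.append("barrier")

def kernel_body(x, trace):
    y = x * 2  # stands in for the hot-path computation
    barrier_if_debug(trace)  # no-op unless FLA_DEBUG=1
    return y
```

With the flag unset, the trace stays empty and the barrier costs nothing beyond a predictable branch.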

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around line 221, tl.debug_barrier() is left in the
hot kernel path; remove it for production or guard it behind a debug flag so it
only runs during development. Replace the bare call with either deletion, or a
conditional such as checking a module-level DEBUG/ENABLE_DEBUG_BARRIER constant
or an environment variable (e.g. if DEBUG: tl.debug_barrier()), or use a
compile-time flag so the barrier is compiled out in release builds.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (5)
fla/ops/kda/chunk_bwd.py (5)

221-221: Remove or guard the debug barrier.

The tl.debug_barrier() at line 221 is in the hot kernel path and will impact performance. Remove it for production builds or guard it with a compile-time debug flag.

Based on past review comments.


178-179: Critical: Block pointer configuration for A is transposed.

The block pointer for matrix A uses shape (BT, T) with strides (1, H * BT), which is transposed relative to how A is allocated in the forward kernel. The forward pass allocates A with shape (T, BT) and strides (H * BT, 1). This mismatch will cause incorrect memory access and wrong gradients.

🔎 Correct configuration matching forward kernel
-    p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
+    p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

Based on past review comments and forward kernel memory layout analysis.


230-230: Critical: Missing transpose in WY backward gradient.

Line 230 computes b_dvb = tl.dot(b_A, b_dv), but the forward pass applies u = A @ (v * beta). For correct backpropagation through a matrix multiplication, the gradient must flow through the transpose:

b_dvb = tl.dot(tl.trans(b_A), b_dv)

Without the transpose, the gradient computation is mathematically incorrect.

🔎 Proposed fix
-                b_dvb = tl.dot(b_A, b_dv)
+                b_dvb = tl.dot(tl.trans(b_A), b_dv)

Based on past review comments and backpropagation rules.


247-247: Critical: Missing transpose in gradient computation.

Line 247 computes b_dkgb = tl.dot(b_A, b_dw), but similar to line 230, the backward pass through matrix A requires the transpose for correct gradient flow:

b_dkgb = tl.dot(tl.trans(b_A), b_dw)

This is another instance of the same mathematical error.

🔎 Proposed fix
-        b_dkgb = tl.dot(b_A, b_dw)
+        b_dkgb = tl.dot(tl.trans(b_A), b_dw)

Based on past review comments.


264-268: Critical: Incorrect matrix inverse gradient formula.

Lines 266-267 compute the gradient of a matrix inverse as -(A @ dA @ A), but the correct formula is:

dX = -A^T @ dA @ A^T

where A = X^{-1}. The current implementation is missing transposes on both matrix multiplications, leading to incorrect gradients.

🔎 Correct implementation
-    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
-    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+    b_dA_temp = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+    b_dA = tl.dot(b_dA_temp, tl.trans(b_A))
     b_dA = tl.where(m_A, -b_dA, 0)

Based on past review comments and matrix calculus rules.

📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 845d665 and a26d3ba.

📒 Files selected for processing (1)
  • fla/ops/kda/chunk_bwd.py (8 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (2)
fla/ops/kda/chunk_bwd.py (2)

29-29: LGTM: Consistent naming convention.

The rename from chunk_bwd_kernel_dAv to chunk_kda_bwd_kernel_dAv improves naming consistency across the KDA operation suite.


327-387: Wrapper function structure is correct.

The wrapper function correctly allocates output tensors (dv2, db, dA), configures the grid for the fused kernel, and properly reassigns dv = dv2 to return the correct gradient path. The parameter passing aligns with the kernel signature.

However, the correctness depends on fixing the critical issues in the kernel itself (lines 178, 230, 247, 264-268).

Comment thread fla/ops/kda/chunk_bwd.py
Comment on lines +184 to +262
for i_k in range(tl.cdiv(K, BK)):
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K

p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))

p_gn = g + (min(T, i_t * BT + BT) - 1) * H*K + o_k
b_gn = tl.load(p_gn, mask=m_k, other=0)

b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32)
b_dgk = tl.zeros([BK], dtype=tl.float32)

p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v_new = tl.make_block_ptr(v_new, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BV]
b_dv = tl.load(p_dv, boundary_check=(0, 1))

b_dgk *= exp2(b_gn)
b_dq *= scale
b_dq = b_dq * exp2(b_g)
b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0)
b_dgk += tl.sum(b_h * b_dh, axis=0)
b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
b_dk += tl.dot(b_v_new, b_dh.to(b_v_new.dtype))
b_dw += tl.dot(b_dv.to(b_v_new.dtype), b_h.to(b_v_new.dtype))

p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dgk += tl.sum(b_dk * b_k, axis=0)
b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk
tl.debug_barrier()
if i_k == 0:
p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv2 = tl.make_block_ptr(dv2, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))

b_dA += tl.dot(b_dv, tl.trans(b_v))

b_dvb = tl.dot(b_A, b_dv)
b_dv2 = b_dvb * b_beta[:, None]
b_db += tl.sum(b_dvb * b_v, 1)

tl.store(p_dv2, b_dv2.to(p_dv2.dtype.element_ty), boundary_check=(0, 1))

b_gk_exp = exp2(b_g)
b_gb = b_gk_exp * b_beta[:, None]
b_dgk *= exp2(b_gn)
b_dq = b_dq * b_gk_exp * scale
b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0)

b_kg = b_k * b_gk_exp

b_dw = -b_dw.to(b_A.dtype)
b_dA += tl.dot(b_dw, tl.trans(b_kg.to(b_A.dtype)))

b_dkgb = tl.dot(b_A, b_dw)
b_db += tl.sum(b_dkgb * b_kg, 1)

p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_kdk = b_k * b_dk
b_dgk += tl.sum(b_kdk, axis=0)
b_dg = b_q * b_dq - b_kdk + m_last[:, None] * b_dgk + b_kg * b_dkgb * b_beta[:, None]
b_dk = b_dk + b_dkgb * b_gb

p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))

🛠️ Refactor suggestion | 🟠 Major

Good fusion design, but correctness depends on fixing critical bugs.

The interleaved K-block loop design successfully fuses WY backward computation with gradient accumulation, keeping intermediate dw in registers and eliminating redundant global memory traffic. This achieves the PR objective of reducing bandwidth and kernel launches.

However, the kernel contains multiple critical mathematical errors (missing transposes at lines 230, 247 and incorrect inverse gradient at lines 264-268) that must be fixed before the fused kernel can be used.
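The memory-traffic argument itself is easy to illustrate with a toy NumPy sketch (hypothetical shapes, not the real kernel): the unfused path materializes dw in a buffer between two passes, while the fused path consumes it locally, and both yield the same result.

```python
import numpy as np

# Toy model of the fusion: unfused writes dw to a global buffer and a
# second kernel reads it back; fused keeps dw in registers (a local).
rng = np.random.default_rng(2)
dv = rng.standard_normal((8, 4))   # gradient w.r.t. v (toy)
h = rng.standard_normal((4, 6))    # state block (toy)
A = rng.standard_normal((8, 8))    # WY matrix (toy)

# Unfused: kernel 1 stores dw, kernel 2 loads and consumes it.
dw_global = -(dv @ h)              # kernel 1: write to global memory
out_unfused = A @ dw_global        # kernel 2: read back and use

# Fused: dw never leaves the kernel.
dw_local = -(dv @ h)
out_fused = A @ dw_local

assert np.allclose(out_unfused, out_fused)  # same math, half the dw traffic
```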

@yzhangcs yzhangcs merged commit 042d998 into main Dec 20, 2025
4 checks passed
@yzhangcs yzhangcs deleted the feature/fuse_bwd_kernels branch December 20, 2025 06:27
@coderabbitai coderabbitai bot mentioned this pull request Feb 11, 2026