[KDA] fused bwd kernels inter and prepare wy #688
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough: Replaces the previous KDA backward path with a fused WY+dqkg kernel/wrapper that accepts the original values, beta, and the inverse matrix A as extra inputs and emits the additional outputs db and dA.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~30 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning) · ✅ Passed checks (2 passed)
Summary of Changes

Hello @Nathancgy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly optimizes the backward pass for KDA operations by fusing two previously separate kernels into a single, more efficient kernel. The fusion directly addresses memory bandwidth bottlenecks and reduces kernel launch overhead, improving performance. The core idea is to keep intermediate gradients in registers, avoiding costly global memory accesses, and to streamline the computation of the remaining gradients by eliminating redundant write-read cycles.

Highlights
Code Review
This pull request introduces a significant performance optimization by fusing two backward pass kernels, chunk_kda_bwd_kernel_inter and prepare_wy_repr_bwd_kda_kernel, into a single chunk_kda_bwd_kernel_inter_wy_fused kernel. This is a great approach to reduce memory bandwidth by keeping intermediate gradients like dw in registers. The changes in chunk.py and the removal of obsolete code in wy_fast.py are well-executed. However, my review of the new fused Triton kernel in chunk_inter.py has identified several critical issues related to incorrect memory addressing and gradient calculations that will lead to incorrect results. These issues need to be addressed to ensure the correctness of the backward pass.
```python
p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))

p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
```
The `tl.make_block_ptr` for matrix `A` is configured incorrectly. The shape and strides describe a transposed view, and the start-offset calculation is wrong, which will lead to out-of-bounds memory access and incorrect results. The shape of `A` for a given head is `(T, BT)`, so the pointer should be configured accordingly.
```diff
- p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
+ p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
```
```python
b_dw_neg_cast = b_dw_neg.to(b_A.dtype)
b_dA += tl.dot(b_dw_neg_cast, tl.trans(b_kbg))

b_dkbg = tl.dot(b_A, b_dw_neg_cast)
```
The gradient calculation for `b_dkbg` is incorrect. The forward pass multiplies by `A` (which represents `Akk_inv`), so the backward pass requires multiplication with `A.T`. The code is missing the transpose on `b_A`.
```diff
- b_dkbg = tl.dot(b_A, b_dw_neg_cast)
+ b_dkbg = tl.dot(tl.trans(b_A), b_dw_neg_cast)
```
```python
b_dA += tl.dot(b_du, tl.trans(b_vb))

b_dvb = tl.dot(b_A, b_du)
```
```python
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
```
The gradient calculation for `dAkk` from `dAkk_inv` (represented by `b_dA`) is incorrect. For a matrix inverse `A = X^{-1}`, the derivative is `dX = -A^T dA A^T`. The current implementation computes `A @ (dA @ A)`, which is mathematically incorrect; both multiplications should use the transpose of `b_A`.
```diff
- b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
- b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+ b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+ b_dA = tl.dot(b_dA_t, tl.trans(b_A))
```
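The matrix-inverse gradient identity used here can be sanity-checked numerically outside the kernel. The sketch below (standalone NumPy, not code from this PR; sizes and seed are arbitrary) verifies that for `Y = inv(X)` and loss `L = sum(G * Y)`, the gradient is `dL/dX = -Y.T @ G @ Y.T`:

```python
import numpy as np

# Numerical check of the matrix-inverse gradient: for Y = inv(X) and
# L = sum(G * Y), the analytic gradient is dL/dX = -Y.T @ G @ Y.T.
rng = np.random.default_rng(0)
n = 4
X = rng.standard_normal((n, n)) + n * np.eye(n)  # well-conditioned matrix
G = rng.standard_normal((n, n))                  # upstream gradient dL/dY

Y = np.linalg.inv(X)
dX = -Y.T @ G @ Y.T                              # formula under test

# Finite-difference probe of a single entry
eps = 1e-6
i, j = 1, 2
Xp = X.copy()
Xp[i, j] += eps
num = (np.sum(G * np.linalg.inv(Xp)) - np.sum(G * Y)) / eps
assert abs(num - dX[i, j]) < 1e-4
```

Dropping the transposes (i.e. using `-Y @ G @ Y`) makes the finite-difference check fail for non-symmetric `X`, which is exactly the bug flagged above.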
Actionable comments posted: 1
🧹 Nitpick comments (1)
fla/ops/kda/chunk_inter.py (1)
231-232: Unused `meta` parameter in grid function. The `meta` parameter is required by Triton's grid API but unused here. This is a known pattern and the static analysis warning can be safely suppressed with a leading underscore.

🔎 Suggested fix:

```diff
- def grid(meta):
+ def grid(_meta):
      return (NT, B * H)
```
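For context, Triton calls the grid callable with the autotuner's chosen configuration; a minimal stand-in (the NT, B, H values here are made up, and the dict merely mimics what Triton would pass) shows why the argument must be accepted even when unused:

```python
# Stand-in sizes; in the kernel wrapper these come from the tensor shapes.
NT, B, H = 16, 2, 4

def grid(_meta):
    # `_meta` receives the autotuned config (block sizes, num_warps, ...)
    # chosen by Triton; this launch grid does not depend on it.
    return (NT, B * H)

assert grid({"BT": 64, "num_warps": 4}) == (16, 8)
```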
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- fla/ops/kda/chunk.py (2 hunks)
- fla/ops/kda/chunk_inter.py (5 hunks)
- fla/ops/kda/wy_fast.py (0 hunks)
💤 Files with no reviewable changes (1)
- fla/ops/kda/wy_fast.py
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/kda/chunk.py (2)
fla/ops/kda/chunk_inter.py (1)
- chunk_kda_bwd_dqkwg_wy_fused (200-261)

fla/ops/kda/wy_fast.py (1)

- recompute_w_u_fwd (103-147)
🪛 GitHub Actions: lint
fla/ops/kda/chunk_inter.py
[error] 1-1: End-of-file-fixer: Fixed missing newline at end of file.
[error] 1-1: Ruff: 1 issue fixed during pre-commit run (formatting/linting).
🪛 Ruff (0.14.8)
fla/ops/kda/chunk_inter.py
231-231: Unused function argument: meta
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (6)
fla/ops/kda/chunk.py (2)
9-12: LGTM! Import changes align with kernel fusion. The imports correctly reflect the transition from separate kernels to the fused `chunk_kda_bwd_dqkwg_wy_fused`, with `prepare_wy_repr_bwd` appropriately removed since its logic is now integrated into the fused kernel.
140-158: LGTM! Fused kernel integration is well-structured. The transition to `chunk_kda_bwd_dqkwg_wy_fused` is correctly implemented:

- `v_org` (original v) is passed alongside the transformed `v_new`
- `A=Akk` properly supplies the attention matrix
- The additional output `dAkk` is correctly propagated to `chunk_kda_bwd_intra`
- The explanatory comments clearly document the memory bandwidth optimization rationale
fla/ops/kda/chunk_inter.py (4)
31-60: LGTM! Well-designed fused kernel signature. The kernel signature properly extends the original with:

- `v_org`, `beta`, `A` inputs for the WY backward computation
- `dv_in` for input gradients; `db`, `dA` for new outputs
- Heuristics and autotune configurations appropriately retained
61-95: LGTM! Pointer setup and indexing are correct. The simplified program ID handling (`i_t`, `i_bh`) and the pointer arithmetic for all tensors, including the new `v_org`, `beta`, `A`, `dv_in`, `db`, `dA`, correctly match the expected memory layouts. Variable-length sequence support via `IS_VARLEN` is properly preserved.
106-170: LGTM! Core computation loop correctly fuses inter-chunk and WY backward logic. Key observations:

- `b_dw` is kept in registers and never written to global memory (per PR objectives)
- WY backward contributions are computed inline (`b_dk_wy`, `b_dg_wy`, `b_db`)
- `b_dk = b_dk_inter + b_dk_wy` properly combines both gradient sources
- `b_dA` accumulation via `tl.dot(b_dw_neg_cast, tl.trans(b_kbg))` is correct

The interleaved K-block loop design eliminates the redundant global memory traffic.
172-197: LGTM! V-dimension loop and dA finalization are correctly implemented. The V-loop properly accumulates the additional `b_dA` and `b_db` contributions from `v_org`. The `dA` finalization sequence (masking → two `tl.dot` with transposed `b_A` → final masking with negation) correctly computes the gradient through the lower-triangular structure.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
fla/ops/kda/chunk_bwd.py (1)
109-276: Fusion design is sound, but critical correctness issues must be fixed first.The kernel fusion approach successfully integrates WY backward preparation into the main backward kernel, keeping intermediate gradients in registers. However, the four critical mathematical errors (lines 178, 231, 261, 268-269) will prevent correct gradient computation and must be addressed before this PR can be merged.
Once these issues are resolved, the fused kernel should deliver the performance benefits described in the PR objectives.
♻️ Duplicate comments (4)
fla/ops/kda/chunk_bwd.py (4)
178-179: CRITICAL: Block pointer for matrix A configured incorrectly. The `tl.make_block_ptr` configuration has swapped dimensions, strides, and offsets. After the pointer offset at line 163, matrix `A` has shape `[T, BT]` with strides `(H*BT, 1)`. The current configuration will cause out-of-bounds memory access and incorrect gradient computation.

🔎 Proposed fix

```diff
- p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
+ p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
```
231-231: CRITICAL: Missing transpose in gradient calculation. The gradient `b_dkbg` is computed without transposing `b_A`. Since the forward pass involves `A @ x`, the backward pass requires `A^T @ dx`. This error will produce incorrect gradients for the key and beta parameters.

🔎 Proposed fix

```diff
- b_dkbg = tl.dot(b_A, b_dw_neg_cast)
+ b_dkbg = tl.dot(tl.trans(b_A), b_dw_neg_cast)
```
261-261: CRITICAL: Missing transpose in gradient calculation for dv. Similar to line 231, the gradient `b_dvb` is computed without transposing `b_A`. Backpropagation through `A @ u` requires `A^T @ du`, so the transpose is mathematically necessary.

🔎 Proposed fix

```diff
- b_dvb = tl.dot(b_A, b_du)
+ b_dvb = tl.dot(tl.trans(b_A), b_du)
```
268-269: CRITICAL: Incorrect gradient formula for matrix inverse. The gradient of a matrix inverse `A = X^{-1}` is `dX = -A^T dA A^T`. The current implementation computes `A @ (dA @ A)` without transposes, which is mathematically incorrect and will produce wrong gradients for `Akk`.

🔎 Proposed fix

```diff
- b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
- b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+ b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+ b_dA = tl.dot(b_dA_t, tl.trans(b_A))
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- fla/ops/kda/chunk.py (2 hunks)
- fla/ops/kda/chunk_bwd.py (6 hunks)
🧰 Additional context used
🪛 Ruff (0.14.8)
fla/ops/kda/chunk_bwd.py
360-360: Unused function argument: meta
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (6)
fla/ops/kda/chunk.py (4)
8-8: LGTM: Import changes align with kernel fusion. The imports correctly reflect the fused backward kernel and the recomputation approach. The removal of `prepare_wy_repr_bwd` is consistent with integrating the WY backward logic into the main kernel.

Also applies to: 11-11
84-93: LGTM: Forward recomputation for backward pass. The recomputation of `w`, `u`, `qg`, `kg` using `recompute_w_u_fwd` is a standard memory-saving technique. The explicit `A=Akk` parameter binding is clear and correct.
133-149: LGTM: Fused kernel call with extended parameters. The fused kernel correctly receives:

- `v_org=v` (original values) and `v=v_new` (transformed values) for the WY backward computation
- `A=Akk` for the inverse matrix operations
- `dv=dv` as both input and output (accumulated gradients)
- Additional output `dAkk` for the downstream intra-chunk backward

This aligns with the fusion objectives to eliminate redundant memory traffic.
150-164: LGTM: Intra-chunk backward updated to use dAkk. The `chunk_kda_bwd_intra` call correctly consumes `dAkk` from the fused kernel output alongside `dAqk`, maintaining proper gradient flow through the backward pass.

fla/ops/kda/chunk_bwd.py (2)
109-138: LGTM: Fused kernel signature with extended parameters. The kernel signature correctly includes all parameters needed for the fused WY backward computation: `v_org` (original values), `beta`, `A` (Akk inverse), `dv_in` (input gradient), and the additional outputs `db` and `dA`.
329-390: LGTM: Wrapper function correctly allocates and routes tensors. The wrapper properly:

- Allocates new output tensors (`dv_out`, `db`, `dA`)
- Distinguishes between the input `dv` (passed as `dv_in`) and the output `dv_out`
- Returns the extended gradient set `(dq, dk, dv_out, db, dg, dA)`

Note: The `meta` parameter at line 360 is required by Triton's grid function interface, so the static analysis warning is a false positive.
Actionable comments posted: 3
🧹 Nitpick comments (1)
fla/ops/kda/chunk_bwd.py (1)
353-358: Consider pre-allocating output tensors with a consistent dtype. Output tensors `dq`, `dk`, and `dg` are allocated with `dtype=torch.float`, while `dv2` and `db` use the input tensor's dtype. For consistency and potential performance benefits, consider explicitly specifying `dtype=torch.float` for all gradient outputs, or document why `dv2` should match the input dtype.

Additionally, line 388 reassigns `dv = dv2`, which means the final `dv` output will have the same dtype as the input `v`. Ensure this matches the expected API behavior.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- fla/ops/kda/chunk_bwd.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/kda/chunk_bwd.py (1)
fla/utils.py (1)
- check_shared_mem (447-453)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (3)
fla/ops/kda/chunk_bwd.py (3)
247-247: Verify whether the gradient calculation requires a transpose of A. Similar to line 230, the gradient `b_dkbg = tl.dot(b_A, b_dw)` may require a transpose based on the forward pass formulation. The past review comment suggested: `b_dkbg = tl.dot(tl.trans(b_A), b_dw_neg_cast)`.

Please verify against the forward pass whether matrix `A` should be transposed here.
360-388: Verify input dv is used correctly in the fused kernel. The function accepts `dv` as an input parameter (line 340) and passes it to the kernel (line 374), but then overwrites it with `dv2` (line 388) before returning.

Looking at the kernel signature (lines 122-123), `dv` and `dv2` are separate parameters:

- `dv` appears to be read in the kernel (line 214: `b_dv = tl.load(p_dv, ...)`)
- `dv2` is written as output (line 234: `tl.store(p_dv2, ...)`)

Please verify:

- Is `dv` an input gradient that should be accumulated, or is it just a workspace?
- Should the function signature clarify that `dv` is an input and the returned `dv` is actually from `dv2`?

The current code suggests `dv` is an upstream gradient that gets accumulated in the kernel, and `dv2` then contains the final output. This flow should be documented.
252-257: Gate gradient formula is correct. The formula at line 256 properly combines all gradient paths:

- The negative sign on `b_k * b_dk` is correct. The delta rule state update is `S = S*exp(g) + β*k*(v - k^T*S)`. The key gradient involves both the forward contribution (through the state) and a correction term from the rank-1 update; the negative sign correctly represents this correction through the gate-scaled state.
- The `b_dgk` masking with `m_last` is correct. The accumulated `b_dgk` represents gradients from both value and key paths that affect the state at the chunk boundary. It should only contribute to the gate gradient at the final position of each chunk, since that gate value governs the decay to subsequent chunks. Intermediate gates affect within-chunk computation through the cumsum, not at boundaries.

The formula structure is validated. The test suite confirms gradient correctness with appropriate tolerances (0.02 for `dg`), and the pattern matches related implementations in `chunk_intra.py`. The exponential gate scaling via `b_gn` and the WY contribution via `b_dg_wy` are both properly incorporated.
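The recurrence referenced above can be made concrete with a tiny per-token NumPy reference. This is a sketch under the assumption of a per-channel gate `g[t]` of size K, not the chunked kernel itself; the function name is made up for illustration:

```python
import numpy as np

def gated_delta_rule_ref(q, k, v, g, beta):
    """Per-token reference of S = S * exp(g) + beta * k (v - k^T S)."""
    T, K = k.shape
    V = v.shape[1]
    S = np.zeros((K, V))
    out = np.zeros((T, V))
    for t in range(T):
        S = np.exp(g[t])[:, None] * S            # per-channel decay gate
        err = v[t] - k[t] @ S                    # delta-rule correction target
        S = S + beta[t] * np.outer(k[t], err)    # rank-1 state update
        out[t] = q[t] @ S                        # readout with the query
    return out

# With a single token and zero initial state, the output reduces to
# beta[0] * (q[0] @ k[0]) * v[0].
q = np.array([[1.0, 2.0]])
k = np.array([[0.5, 0.5]])
v = np.array([[3.0, -1.0]])
g = np.zeros((1, 2))
beta = np.array([2.0])
out = gated_delta_rule_ref(q, k, v, g, beta)  # → [[9.0, -3.0]]
```

The rank-1 `outer(k, err)` term is exactly the correction whose gradient produces the negative `b_k * b_dk` contribution discussed above.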
```python
p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
b_A = tl.load(p_A, boundary_check=(0, 1))
```
🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash
# Search for A allocation and block pointer usage in chunk_kda_bwd.py and related files
echo "=== Looking for A allocation in chunk_kda_bwd.py and related forward kernels ==="
# Find the wrapper function for chunk_kda_bwd_wy_dqkg_fused
rg -n "def chunk_kda_bwd_wy_dqkg_fused" fla/ops/kda/ -A 50 | head -100
echo ""
echo "=== Looking for A creation in forward or other files ==="
# Search for A.new_empty or similar allocations related to KDA
fd . fla/ops/kda/ --type f -name "*.py" | head -20
echo ""
echo "=== Check what A is in the backward kernel context ==="
rg -n "p_A.*make_block_ptr" fla/ops/kda/ -B 5 -A 2
```

Repository: fla-org/flash-linear-attention
Length of output: 10739
Correct block pointer configuration to match forward kernel memory layout.
The forward kernel (chunk_intra.py) allocates and accesses A with shape (T, BT) and strides (H*BT, 1). The backward kernel must use the same configuration to access the same memory layout correctly.
Change line 178 from:

```python
p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
```

To:

```python
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
```

The transposed configuration (BT, T) with strides (1, H*BT) is incompatible with how A is allocated in the forward pass and will cause incorrect memory access patterns.
🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around lines 178 to 179, the block pointer for A is
configured with shape (BT, T) and strides (1, H * BT) which is transposed
relative to the forward kernel; replace that configuration so it matches the
forward kernel's allocation/access (shape (T, BT) and strides (H * BT, 1)),
i.e., change the make_block_ptr call to use (T, BT), (H * BT, 1), and adjust the
offset/transpose parameters accordingly so the pointer addresses the same memory
layout as the forward pass.
```python
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
b_dA += tl.dot(b_dv, tl.trans(b_vb))

b_dvb = tl.dot(b_A, b_dv)
```
🧩 Analysis chain

🏁 Script executed:

```shell
# Find the forward kernel file
find . -name "*chunk_fwd*" -o -name "*kda*fwd*" | head -20
```

Repository: fla-org/flash-linear-attention
Length of output: 56

🏁 Script executed:

```shell
# List files in the kda directory to find forward pass
ls -la fla/ops/kda/
```

Repository: fla-org/flash-linear-attention
Length of output: 775

🏁 Script executed:

```shell
# Look at the forward chunk.py file to understand the forward pass
cat -n fla/ops/kda/chunk.py | head -100
```

Repository: fla-org/flash-linear-attention
Length of output: 3402

🏁 Script executed:

```shell
# Search for where b_dvb is computed and look at surrounding context
grep -n "b_dvb\|b_dkbg\|b_dA" fla/ops/kda/chunk_bwd.py | head -20
```

Repository: fla-org/flash-linear-attention
Length of output: 1046

🏁 Script executed:

```shell
# Look at the context around line 230 to understand the forward pass computation
sed -n '200,275p' fla/ops/kda/chunk_bwd.py
```

Repository: fla-org/flash-linear-attention
Length of output: 3681

🏁 Script executed:

```shell
# Look at the wy_fast.py file to see how the WY representation forward pass is computed
cat -n fla/ops/kda/wy_fast.py | grep -A 20 -B 5 "v_new\|vb"
```

Repository: fla-org/flash-linear-attention
Length of output: 3501

🏁 Script executed:

```shell
# Now let's verify the matrix inverse gradient computation at lines 268-269
sed -n '260,275p' fla/ops/kda/chunk_bwd.py
```

Repository: fla-org/flash-linear-attention
Length of output: 1022

🏁 Script executed:

```shell
# Check if there are any comments or documentation about the WY representation
grep -n "WY\|inverse\|transpose" fla/ops/kda/wy_fast.py | head -20
```

Repository: fla-org/flash-linear-attention
Length of output: 56

🏁 Script executed:

```shell
# Let me look at the complete recompute_w_u_fwd function to understand the forward pass better
sed -n '212,250p' fla/ops/kda/wy_fast.py
```

Repository: fla-org/flash-linear-attention
Length of output: 1134

🏁 Script executed:

```shell
# Let's examine the mathematical operation more carefully - check what the actual forward computation is
sed -n '40,75p' fla/ops/kda/wy_fast.py
```

Repository: fla-org/flash-linear-attention
Length of output: 1567
Missing transpose in backward gradient computation through WY representation.
Line 230 computes b_dvb = tl.dot(b_A, b_dv), but the forward pass applies b_u = tl.dot(b_A, b_vb) where b_vb = b_v * b_beta. For the gradient to flow correctly through a matrix multiplication, the backward pass must apply the transpose: b_dvb = tl.dot(tl.trans(b_A), b_dv).
The same issue affects line 247: b_dkbg = tl.dot(b_A, b_dw) should be b_dkbg = tl.dot(tl.trans(b_A), b_dw).
Additionally, lines 268-269 implement the gradient of matrix inverse incorrectly. The derivative of f(X^{-1}) requires dX = -X^{-T} @ dY @ X^{-T}, but the current code only computes A @ dA @ A without proper transposes.
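The transpose rule invoked here — for `u = A @ vb`, the gradient with respect to `vb` is `A.T @ du`, not `A @ du` — can be checked numerically in a few lines. This is a standalone NumPy sketch, independent of the kernel, with arbitrary shapes:

```python
import numpy as np

# Forward: u = A @ vb. For L = sum(du * u), the gradient w.r.t. vb is A.T @ du.
rng = np.random.default_rng(1)
A = rng.standard_normal((5, 5))
vb = rng.standard_normal((5, 3))
du = rng.standard_normal((5, 3))   # upstream gradient dL/du

dvb = A.T @ du                     # backward with the required transpose

# Finite-difference probe of one entry confirms the transpose rule
eps = 1e-6
vb_p = vb.copy()
vb_p[2, 1] += eps
num = (np.sum(du * (A @ vb_p)) - np.sum(du * (A @ vb))) / eps
assert abs(num - dvb[2, 1]) < 1e-4
```

Because the forward map is linear in `vb`, the finite difference here is exact up to floating-point error, so the check is a sharp one: replacing `A.T @ du` with `A @ du` fails for any non-symmetric `A`.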
🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around lines 230, 247 and 268-269, the backward pass
applies matrix multiplications and inverse gradients without required
transposes: change line 230 b_dvb = tl.dot(b_A, b_dv) to b_dvb =
tl.dot(tl.trans(b_A), b_dv); change line 247 b_dkbg = tl.dot(b_A, b_dw) to
b_dkbg = tl.dot(tl.trans(b_A), b_dw); and replace the incorrect inverse-gradient
computation at lines 268-269 (which currently uses A @ dA @ A) with the correct
form using transposed inverses: dX = -A_T @ dY @ A_T where A_T is tl.trans(A)
(i.e., use A_T = tl.trans(A) and compute -tl.dot(tl.dot(A_T, dY), A_T)). Ensure
variable names match existing locals and use tl.trans consistently.
```python
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_dA = tl.where(m_A, b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA = tl.where(m_A, -b_dA, 0)
```
Critical: Incorrect gradient calculation for matrix inverse.
The gradient of a matrix inverse is computed incorrectly. For A = X^{-1}, the derivative is:
dX = -A^T @ dA @ A^T
However, the current implementation computes:

```python
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)   # dA @ A
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))   # A @ (dA @ A)
b_dA = tl.where(m_A, -b_dA, 0)           # negation
```

This computes `-(A @ dA @ A)` instead of `-(A^T @ dA @ A^T)`. The transposes are missing.
🔎 Proposed fix

```diff
- b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
- b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+ b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+ b_dA = tl.dot(b_dA_t, tl.trans(b_A))
  b_dA = tl.where(m_A, -b_dA, 0)
```

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around lines 266 to 270, the gradient for the matrix
inverse is computed as -(A @ dA @ A) but must be -(A^T @ dA @ A^T); replace the
three dot calls with operations that compute b_dA = - (b_A.T @ b_dA @ b_A.T)
(respecting dtype casts as needed), then apply the existing mask m_A (i.e., b_dA
= tl.where(m_A, b_dA, 0)). Ensure you cast operands to matching dtypes before
each dot and perform the transposes on b_A (and on the intermediate if required)
so the final result matches -(A^T @ dA @ A^T).
Actionable comments posted: 1
♻️ Duplicate comments (4)
fla/ops/kda/chunk_bwd.py (4)
178-179: CRITICAL: Block pointer configuration for A still incorrect (previously flagged). This issue was flagged in previous reviews but remains unfixed. The forward kernel allocates `A` with shape `(T, BT)` and strides `(H*BT, 1)`. The backward kernel must use the same memory layout configuration.

🔎 Required fix

```diff
- p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
+ p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
```

Based on past review comments.
229-235: CRITICAL: Missing transpose in WY backward gradient computation (previously flagged). Line 231 computes `b_dvb = tl.dot(b_A, b_dv)`, but the forward pass applies `u = A @ vb`. For correct backpropagation through matrix multiplication, the gradient must use the transpose: `dvb = A^T @ du`.

🔎 Required fix

```diff
- b_dvb = tl.dot(b_A, b_dv)
+ b_dvb = tl.dot(tl.trans(b_A), b_dv)
```

Based on past review comments.
244-249: CRITICAL: Missing transpose in K-branch gradient computation (previously flagged). Line 247 computes `b_dkbg = tl.dot(b_A, b_dw)` without transposing `b_A`. For correct backpropagation through the matrix multiplication `A @ x`, the gradient must be `A^T @ dx`.

🔎 Required fix

```diff
- b_dkbg = tl.dot(b_A, b_dw)
+ b_dkbg = tl.dot(tl.trans(b_A), b_dw)
```

Based on past review comments.
265-269: CRITICAL: Incorrect matrix inverse gradient formula (previously flagged). The derivative of a matrix inverse `Y = X^{-1}` is `dX = -Y^T @ dY @ Y^T`. The current implementation computes `-(A @ dA @ A)` without the required transposes.

🔎 Required fix

```diff
- b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
- b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+ b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+ b_dA = tl.dot(b_dA_t, tl.trans(b_A))
  b_dA = tl.where(m_A, -b_dA, 0)
```

Based on past review comments.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- fla/ops/kda/chunk_bwd.py (6 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (4)
fla/ops/kda/chunk_bwd.py (4)
109-137: LGTM: Kernel signature updated to support fused WY backward. The addition of `v_new`, `beta`, `A` as inputs and `dv2`, `db`, `dA` as outputs correctly extends the kernel API to support the fused WY representation backward pass within the same kernel launch.
139-152: LGTM: Grid simplified to 2D for fused kernel. The removal of the `i_k` program axis is correct: the K-dimension loop is now handled within the kernel body (lines 184-263), eliminating one grid dimension and enabling the fusion.
328-388: LGTM: Wrapper correctly orchestrates fused kernel launch. The wrapper properly allocates output tensors (`dv2`, `db`, `dA`), launches the fused kernel with the extended parameter set, and returns the expected gradients. The reassignment `dv = dv2` at line 387 correctly uses the kernel's output before returning.
184-263: Approve fusion design, but correctness depends on fixing the transpose issues. The main loop successfully fuses the K-block and V-block iterations, computing WY backward contributions conditionally (`i_k == 0`) and keeping the intermediate gradient `dw` in registers to avoid redundant global memory traffic. This achieves the PR's stated goal of reducing memory bandwidth and kernel launch overhead.

However, the kernel will produce incorrect results until the transpose issues at lines 231, 247, and 267-268 are fixed.
```python
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dgk += tl.sum(b_dk * b_k, axis=0)
b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk
tl.debug_barrier()
```
Consider removing or conditionalizing the debug barrier.
The tl.debug_barrier() is useful for debugging synchronization issues but may impact performance in production. Consider removing it once the kernel is validated, or guard it with a compile-time flag.
🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around line 221, tl.debug_barrier() is left in the
hot kernel path; remove it for production or guard it behind a debug flag so it
only runs during development. Replace the bare call with either deletion, or a
conditional such as checking a module-level DEBUG/ENABLE_DEBUG_BARRIER constant
or an environment variable (e.g. if DEBUG: tl.debug_barrier()), or use a
compile-time flag so the barrier is compiled out in release builds.
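A minimal sketch of the suggested guard, assuming an environment-variable flag (the name `FLA_DEBUG_BARRIER` and the op list are made up for illustration). In Triton, a plain Python `if` on a module-level constant is resolved while the kernel is traced, so the barrier simply never appears in the compiled release kernel:

```python
import os

# Hypothetical flag name; the repository would pick its own convention.
DEBUG_BARRIER = os.environ.get("FLA_DEBUG_BARRIER", "0") == "1"

def kernel_ops(debug_barrier: bool = DEBUG_BARRIER):
    # Stand-in for the kernel body: the barrier only enters the op stream
    # when the flag is set, mirroring `if DEBUG_BARRIER: tl.debug_barrier()`
    # inside a @triton.jit kernel.
    ops = ["load_tiles", "compute_dg"]
    if debug_barrier:
        ops.append("debug_barrier")
    ops.append("store_dg")
    return ops
```

With the flag unset the release path contains no barrier at all, so there is no runtime cost in production.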
Actionable comments posted: 1
♻️ Duplicate comments (5)
fla/ops/kda/chunk_bwd.py (5)
221-221: Remove or guard the debug barrier. The `tl.debug_barrier()` at line 221 is in the hot kernel path and will impact performance. Remove it for production builds or guard it with a compile-time debug flag.

Based on past review comments.
178-179: Critical: Block pointer configuration for A is transposed. The block pointer for matrix `A` uses shape `(BT, T)` with strides `(1, H * BT)`, which is transposed relative to how `A` is allocated in the forward kernel. The forward pass allocates `A` with shape `(T, BT)` and strides `(H * BT, 1)`. This mismatch will cause incorrect memory access and wrong gradients.

🔎 Correct configuration matching forward kernel

```diff
- p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
+ p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
```

Based on past review comments and forward kernel memory layout analysis.
230-230: Critical: Missing transpose in WY backward gradient. Line 230 computes `b_dvb = tl.dot(b_A, b_dv)`, but the forward pass applies `u = A @ (v * beta)`. For correct backpropagation through a matrix multiplication, the gradient must flow through the transpose:

```python
b_dvb = tl.dot(tl.trans(b_A), b_dv)
```

Without the transpose, the gradient computation is mathematically incorrect.

🔎 Proposed fix

```diff
- b_dvb = tl.dot(b_A, b_dv)
+ b_dvb = tl.dot(tl.trans(b_A), b_dv)
```

Based on past review comments and backpropagation rules.
247-247: Critical: Missing transpose in gradient computation. Line 247 computes `b_dkgb = tl.dot(b_A, b_dw)`, but similar to line 230, the backward pass through matrix `A` requires the transpose for correct gradient flow:

```python
b_dkgb = tl.dot(tl.trans(b_A), b_dw)
```

This is another instance of the same mathematical error.

🔎 Proposed fix

```diff
- b_dkgb = tl.dot(b_A, b_dw)
+ b_dkgb = tl.dot(tl.trans(b_A), b_dw)
```

Based on past review comments.
264-268: Critical: Incorrect matrix inverse gradient formula. Lines 266-267 compute the gradient of a matrix inverse as `-(A @ dA @ A)`, but the correct formula is `dX = -A^T @ dA @ A^T` where `A = X^{-1}`. The current implementation is missing transposes on both matrix multiplications, leading to incorrect gradients.

🔎 Correct implementation

```diff
- b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
- b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+ b_dA_temp = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+ b_dA = tl.dot(b_dA_temp, tl.trans(b_A))
  b_dA = tl.where(m_A, -b_dA, 0)
```

Based on past review comments and matrix calculus rules.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- fla/ops/kda/chunk_bwd.py (8 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (2)
fla/ops/kda/chunk_bwd.py (2)
29-29: LGTM: Consistent naming convention. The rename from `chunk_bwd_kernel_dAv` to `chunk_kda_bwd_kernel_dAv` improves naming consistency across the KDA operation suite.
327-387: Wrapper function structure is correct.

The wrapper function correctly allocates output tensors (`dv2`, `db`, `dA`), configures the grid for the fused kernel, and properly reassigns `dv = dv2` to return the correct gradient path. The parameter passing aligns with the kernel signature.

However, the correctness depends on fixing the critical issues in the kernel itself (lines 178, 230, 247, 264-268).
```python
for i_k in range(tl.cdiv(K, BK)):
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K

p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))

p_gn = g + (min(T, i_t * BT + BT) - 1) * H*K + o_k
b_gn = tl.load(p_gn, mask=m_k, other=0)

b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32)
b_dgk = tl.zeros([BK], dtype=tl.float32)

p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
for i_v in range(tl.cdiv(V, BV)):
p_v_new = tl.make_block_ptr(v_new, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# [BT, BV]
b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BT, BV]
b_dv = tl.load(p_dv, boundary_check=(0, 1))

b_dgk *= exp2(b_gn)
b_dq *= scale
b_dq = b_dq * exp2(b_g)
b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0)
b_dgk += tl.sum(b_h * b_dh, axis=0)
b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
b_dk += tl.dot(b_v_new, b_dh.to(b_v_new.dtype))
b_dw += tl.dot(b_dv.to(b_v_new.dtype), b_h.to(b_v_new.dtype))

p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dgk += tl.sum(b_dk * b_k, axis=0)
b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk
tl.debug_barrier()
if i_k == 0:
p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv2 = tl.make_block_ptr(dv2, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))

b_dA += tl.dot(b_dv, tl.trans(b_v))

b_dvb = tl.dot(b_A, b_dv)
b_dv2 = b_dvb * b_beta[:, None]
b_db += tl.sum(b_dvb * b_v, 1)

tl.store(p_dv2, b_dv2.to(p_dv2.dtype.element_ty), boundary_check=(0, 1))

b_gk_exp = exp2(b_g)
b_gb = b_gk_exp * b_beta[:, None]
b_dgk *= exp2(b_gn)
b_dq = b_dq * b_gk_exp * scale
b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0)

b_kg = b_k * b_gk_exp

b_dw = -b_dw.to(b_A.dtype)
b_dA += tl.dot(b_dw, tl.trans(b_kg.to(b_A.dtype)))

b_dkgb = tl.dot(b_A, b_dw)
b_db += tl.sum(b_dkgb * b_kg, 1)

p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_kdk = b_k * b_dk
b_dgk += tl.sum(b_kdk, axis=0)
b_dg = b_q * b_dq - b_kdk + m_last[:, None] * b_dgk + b_kg * b_dkgb * b_beta[:, None]
b_dk = b_dk + b_dkgb * b_gb

p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
```
🛠️ Refactor suggestion | 🟠 Major
Good fusion design, but correctness depends on fixing critical bugs.
The interleaved K-block loop design successfully fuses WY backward computation with gradient accumulation, keeping intermediate dw in registers and eliminating redundant global memory traffic. This achieves the PR objective of reducing bandwidth and kernel launches.
However, the kernel contains multiple critical mathematical errors (missing transposes at lines 230, 247 and incorrect inverse gradient at lines 264-268) that must be fixed before the fused kernel can be used.
8 * 4k nsys profile, mean (ns):

fused: 1151008.7
not fused: 780636.6 + 652113.3
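A quick sanity calculation on the means above (plain Python, just arithmetic on the reported numbers):

```python
# Mean kernel times from the nsys profile above, in nanoseconds
fused = 1151008.7
not_fused = 780636.6 + 652113.3  # inter kernel + wy backward kernel

speedup = not_fused / fused
reduction = 1 - fused / not_fused
print(f"speedup: {speedup:.2f}x, time saved: {reduction:.1%}")
# prints "speedup: 1.24x, time saved: 19.7%"
```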
This PR fuses chunk_kda_bwd_kernel_inter and prepare_wy_repr_bwd_kda_kernel into a single kernel to reduce memory bandwidth usage in the backward pass.
Previously, the backward pass computed inter-chunk gradients in one kernel (writing dw, dk, dg to global memory), then immediately read them back in the WY backward kernel. The fused kernel eliminates this redundant memory traffic by keeping dw in registers and computing the WY backward contributions in the same kernel launch.
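The fusion idea can be sketched in a few lines. This is a hypothetical NumPy illustration of the data flow, not the Triton code; the names (`dw`, `kg`) and shapes are illustrative, and the real kernel tiles these products over blocks:

```python
import numpy as np

# Hypothetical per-chunk shapes: T tokens, K key dims, V value dims
T, K, V = 8, 5, 4
rng = np.random.default_rng(0)
dv = rng.standard_normal((T, V))   # gradient w.r.t. the chunk's values
h = rng.standard_normal((V, K))    # chunk-level state
A = rng.standard_normal((T, T))    # WY representation matrix
kg = rng.standard_normal((T, K))   # gated keys

def bwd_unfused(dv, h, A, kg):
    # Kernel 1 (inter): compute dw and write it to global memory
    dw = -(dv @ h)                 # round-trips through HBM
    # Kernel 2 (wy backward): read dw back and consume it
    return dw @ kg.T, A @ dw

def bwd_fused(dv, h, A, kg):
    # Single kernel: dw stays in registers between the two stages
    dw = -(dv @ h)
    return dw @ kg.T, A @ dw

for a, b in zip(bwd_unfused(dv, h, A, kg), bwd_fused(dv, h, A, kg)):
    assert np.allclose(a, b)       # identical math, less memory traffic
```

The two paths compute the same values; the only difference is whether the intermediate `dw` is materialized to global memory between the two stages, which is exactly the traffic this PR removes.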
Key changes: