[KDA] Speed up chunk_kda by introducing lowerbound gate #703
Conversation
Caution: Review failed. The pull request is closed.

Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a lower-bound-aware chunked gating path and safe_gate flag, threads safe_gate/lower_bound/disable_recompute through forward/backward, expands saved intermediates (w, u, qg, kg, v_new, h), introduces an intra sub-chunk kernel with gather/varlen branches, and updates recompute/WY wiring and tests.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant ChunkFunc as ChunkKDAFunction
    participant Indexer as prepare_chunk_indices
    participant Gate as kda_gate_chunk_cumsum
    participant Intra as chunk_kda_fwd_intra
    participant WY as recompute_w_u_fwd_kda
    participant Mem as MemoryStore
    Client->>ChunkFunc: forward(q,k,v,..., safe_gate, lower_bound, disable_recompute)
    alt cu_seqlens provided
        ChunkFunc->>Indexer: prepare_chunk_indices(cu_seqlens)
        Indexer-->>ChunkFunc: chunk_indices
    end
    alt safe_gate enabled
        ChunkFunc->>Gate: compute gated A/g (lower_bound, scale)
        Gate-->>ChunkFunc: gate outputs (A_log, g, dt_bias, ...)
    end
    ChunkFunc->>Intra: chunk_kda_fwd_intra(..., safe_gate, disable_recompute, chunk_indices)
    Intra-->>ChunkFunc: (w,u,qg,kg,Aqk,Akk,v_new,h)
    alt recompute allowed
        ChunkFunc->>WY: recompute_w_u_fwd_kda(...)
        WY-->>ChunkFunc: (w,u,qg,kg)
    else
        ChunkFunc->>Mem: preserve intermediates for backward
    end
    ChunkFunc->>Mem: save state (w,u,qg,kg,v_new,h,safe_gate,lower_bound,disable_recompute)
    Client->>ChunkFunc: backward(grad_out)
    ChunkFunc->>ChunkFunc: restore or recompute intermediates
    alt safe_gate enabled
        ChunkFunc->>Gate: backward through gate (uses saved/recomputed intermediates)
    end
    ChunkFunc-->>Client: gradients (dq,dk,dv,dA_log,ddt_bias)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ 1 failed check (warning), ✅ 2 passed.
Summary of Changes

Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request focuses on optimizing the `chunk_kda` operation.
Code Review
This pull request introduces optimizations to chunk_kda by adding a safe_gate option and a disable_recompute flag to trade memory for speed. The changes are extensive, touching the forward and backward passes, and introducing new Triton kernels and fallback paths for older Triton versions.
My review has identified a critical issue where a dimension assertion was removed while the underlying kernels still have the limitation, which could lead to runtime errors. I've also pointed out a minor style inconsistency and a potential precision degradation that comes with one of the optimizations.
Overall, the changes are substantial and aim to improve performance, but the critical issue should be addressed.
```diff
 if IS_TF32_SUPPORTED:
-    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32x3')
+    SOLVE_TRIL_DOT_PRECISION = tl.constexpr('tf32')
```
Changing SOLVE_TRIL_DOT_PRECISION from 'tf32x3' to 'tf32' may improve performance but can reduce numerical precision. This seems to be reflected in the increased error tolerances in the associated tests. This precision-performance trade-off should be documented with a code comment if it is intentional.
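For readers unfamiliar with the trade-off being discussed: `tf32` rounds each operand to a 10-bit mantissa before the multiply, while `tf32x3` splits each operand into a high and low TF32 part and issues three TF32 products to recover most of the discarded bits. The following numpy emulation is an illustrative sketch of the technique, not FLA or Triton code (real tensor cores additionally accumulate in fp32):

```python
import numpy as np

def to_tf32(x: np.ndarray) -> np.ndarray:
    """Emulate TF32 input rounding: keep only 10 mantissa bits of a float32."""
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)  # zero low 13 bits

def dot_tf32(a, b):
    # single pass: both operands rounded to TF32 once
    return np.dot(to_tf32(a).astype(np.float64), to_tf32(b).astype(np.float64))

def dot_tf32x3(a, b):
    # three passes: hi*hi + lo*hi + hi*lo recovers most of the lost mantissa
    a_hi, b_hi = to_tf32(a), to_tf32(b)
    a_lo, b_lo = to_tf32(a - a_hi), to_tf32(b - b_hi)
    f64 = lambda t: t.astype(np.float64)
    return (np.dot(f64(a_hi), f64(b_hi))
            + np.dot(f64(a_lo), f64(b_hi))
            + np.dot(f64(a_hi), f64(b_lo)))

rng = np.random.default_rng(0)
a = rng.standard_normal(256).astype(np.float32)
b = rng.standard_normal(256).astype(np.float32)
ref = np.dot(a.astype(np.float64), b.astype(np.float64))
err_tf32 = abs(dot_tf32(a, b) - ref)
err_tf32x3 = abs(dot_tf32x3(a, b) - ref)
```

In this emulation the three-pass result tracks the full-precision dot product far more closely than the single-pass one, which is exactly why dropping from `tf32x3` to `tf32` trades accuracy for fewer tensor-core passes.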
Actionable comments posted: 1
🧹 Nitpick comments (3)
tests/ops/test_kda.py (2)
311-313: Consider documenting the safe_gate restriction. The assertion `assert use_gate_in_kernel is False` when `safe_gate` is True restricts safe_gate testing to non-kernel gate paths in varlen tests. If this is intentional (e.g., safe_gate with kernel gating not yet implemented for varlen), consider adding a comment or `pytest.skip` with a reason.

Looking at the test cases (lines 273-277), `safe_gate=True` only appears with `use_gate_in_kernel=False`, so this is consistent, but explicit documentation would help future maintainers.
355-356: Verify reference function selection for varlen tests with safe_gate. The varlen test uses `naive_kda_gate` regardless of the `safe_gate` value, while `test_chunk` (line 200) selects between `naive_kda_gate` and `naive_kda_lowerbound_gate` based on `safe_gate`.

Given that `safe_gate` test cases in varlen all have `use_gate_in_kernel=False` (lines 273-277), the gate function selection here may not matter since the kernel gate path isn't exercised. However, for consistency and future-proofing, consider using the same conditional selection pattern.

fla/ops/kda/chunk.py (1)
435-437: Improve docstring for cu_seqlens_cpu. The docstring for `cu_seqlens_cpu` duplicates `cu_seqlens`'s description. Consider clarifying its purpose: it appears to be a CPU-side copy used to avoid GPU synchronization when preparing chunk indices.

🔎 Suggested improvement

```diff
 cu_seqlens_cpu (torch.LongTensor):
-    Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
-    consistent with the FlashAttention API.
+    CPU-resident copy of `cu_seqlens` of shape `[N+1]`.
+    When provided, avoids GPU-to-CPU synchronization during chunk index preparation.
```
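To make the suggested wording concrete, here is a small sketch of why a host-resident `cu_seqlens` helps (the function name and layout below are assumptions for illustration; the real `prepare_chunk_indices` lives in `fla/ops/utils/index.py`): enumerating per-(sequence, chunk) indices requires iterating over the sequence lengths, which forces a device-to-host sync if those lengths exist only on the GPU.

```python
from typing import List, Tuple

def chunk_indices_from_cpu(cu_seqlens_cpu: List[int], chunk_size: int) -> List[Tuple[int, int]]:
    """Enumerate (sequence_id, chunk_id) pairs from host-side cumulative lengths."""
    indices = []
    for n in range(len(cu_seqlens_cpu) - 1):
        seq_len = cu_seqlens_cpu[n + 1] - cu_seqlens_cpu[n]
        num_chunks = (seq_len + chunk_size - 1) // chunk_size  # ceil division
        indices.extend((n, c) for c in range(num_chunks))
    return indices
```

For example, `chunk_indices_from_cpu([0, 5, 12], 4)` yields `(0, 0), (0, 1)` for the 5-token sequence and `(1, 0), (1, 1)` for the 7-token one; with a CPU copy of the offsets, this loop never touches the device.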
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)

- fla/ops/kda/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_inter.py
- fla/ops/kda/chunk_intra.py
- fla/ops/kda/chunk_intra_token_parallel.py
- fla/ops/kda/gate.py
- fla/ops/kda/wy_fast.py
- tests/ops/test_kda.py
🧰 Additional context used
🧬 Code graph analysis (3)
fla/ops/kda/chunk_inter.py (4)

- fla/ops/utils/index.py (1): prepare_chunk_indices (112-122)
- fla/utils.py (1): check_shared_mem (454-460)
- fla/ops/kda/gate.py (2): grid (215-216), grid (439-439)
- fla/ops/common/chunk_o.py (1): grid (506-506)
fla/ops/kda/chunk.py (7)

- fla/ops/common/chunk_delta_h.py (1): chunk_gated_delta_rule_bwd_dhu (517-572)
- fla/ops/kda/chunk_bwd.py (1): chunk_kda_bwd_dAv (275-323)
- fla/ops/kda/chunk_inter.py (1): chunk_kda_bwd_dqkwg (142-191)
- fla/ops/kda/chunk_intra.py (2): chunk_kda_bwd_intra (836-900), chunk_kda_fwd_intra (739-833)
- fla/ops/kda/gate.py (1): kda_gate_chunk_cumsum (416-456)
- fla/ops/kda/wy_fast.py (1): recompute_w_u_fwd (212-256)
- fla/ops/utils/index.py (1): prepare_chunk_indices (112-122)
tests/ops/test_kda.py (2)

- fla/ops/kda/gate.py (3): naive_kda_lowerbound_gate (50-62), naive_kda_gate (19-47), backward (301-310)
- fla/ops/kda/chunk.py (1): backward (307-379)
🔇 Additional comments (29)
fla/ops/kda/gate.py (1)
2-2: LGTM! Minor header comment update with no functional impact.
fla/ops/kda/chunk_bwd.py (1)
220-220: LGTM! Adding a comment to document the critical synchronization barrier is a good practice for maintainability.
fla/ops/kda/chunk_intra_token_parallel.py (2)
106-108: LGTM! Separating the computation from masking improves code clarity while maintaining functional equivalence.
111-114: Verify the index calculation change. The storage index changed from `j % BC` to `j - i_ts`. Given that `j` iterates from `i_ts` to at most `i_ts + BC - 1`, both expressions should yield values in `[0, BC)`. The new form `j - i_ts` is more explicit about the offset from the sub-chunk start.

Please verify this change maintains correctness for edge cases where sequences don't align to chunk boundaries.
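The claimed equivalence is easy to check in plain Python: when a sub-chunk starts at a multiple of `BC`, `j % BC` and `j - i_ts` agree, but only the latter remains the offset-from-start if a sub-chunk begins off a `BC` boundary (a sketch using the kernel's variable names, with toy values):

```python
BC = 16  # sub-chunk size, as in the kernel

# aligned sub-chunk starts: both forms index the same slot
for i_ts in range(0, 64, BC):
    for j in range(i_ts, i_ts + BC):
        assert j % BC == j - i_ts

# an unaligned start (e.g. a varlen sequence boundary): only j - i_ts
# still gives the offset from the sub-chunk start
i_ts, j = 10, 13
assert j - i_ts == 3
assert j % BC == 13  # no longer the offset within the sub-chunk
```

This is consistent with the review's observation that `j - i_ts` is the safer, more explicit form for sequences that don't align to chunk boundaries.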
fla/ops/kda/chunk_inter.py (3)
1-14: LGTM! Standard copyright header, imports, and block size configuration consistent with other KDA modules.

31-139: Inter-chunk backward kernel implementation looks correct. The kernel correctly:

- Handles variable-length sequences via `IS_VARLEN`
- Computes gradient contributions `dq`, `dk`, `dw`, `dg` using the chain rule
- Applies gating via `exp2` for numerical stability
- Uses block pointers with boundary checks

Minor observation: `m_t` is computed at line 71 and redefined at line 124. While the second definition overwrites with identical logic, consider removing the redundant recomputation for clarity.
186-189: Bug: Incorrect parameter name in kernel call. The kernel parameter is named `K` (line 49), but the call passes `D=K`. This will cause a parameter mismatch.

🔎 Proposed fix

```diff
     T=T,
     H=H,
-    D=K,
+    K=K,
     V=V,
     BT=BT,
```

Likely an incorrect or invalid review comment.
tests/ops/test_kda.py (2)
150-167: LGTM! Test parameterization properly extended to cover the new `safe_gate` and `disable_recompute` flags with good coverage of different combinations.

196-203: LGTM! The `safe_gate` logic correctly:

- Sets `lower_bound` when enabled
- Clamps gate values for the non-kernel gate path
- Selects the appropriate naive gate function for reference comparison
fla/ops/kda/wy_fast.py (4)
22-22: LGTM! Adding the `tf32x3` precision option provides additional autotuning flexibility for TF32-capable hardware.

28-28: LGTM! Kernel rename to `recompute_w_u_fwd_kda_kernel` improves namespacing and distinguishes KDA-specific kernels.

116-116: LGTM! Kernel rename to `prepare_wy_repr_bwd_kda_kernel` maintains consistency with the forward kernel naming.

221-256: LGTM! The expanded return type `tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]` correctly reflects the addition of optional `qg` and `kg` outputs. The implementation properly handles conditional allocation based on input parameters.

fla/ops/kda/chunk.py (7)
10-17: LGTM! Import additions properly bring in the new modules needed for the extended KDA functionality.

32-47: LGTM! Forward function signature and parameter propagation correctly extended for `safe_gate` and `disable_recompute` support.

73-77: LGTM! Memory optimization pattern correctly implemented: intermediate tensors are set to `None` when `disable_recompute=False` to allow garbage collection, while being preserved when recomputation is disabled.

99-122: LGTM! The conditional recomputation logic correctly handles both paths:

- When `disable_recompute=False`: recomputes `w`, `u`, `qg`, `kg`, `h`, `v_new`
- When `disable_recompute=True`: retrieves cached values from `**kwargs`

This provides flexibility for memory vs. compute trade-offs.
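The pattern described above (drop intermediates in forward, rebuild them in backward unless `disable_recompute` is set) can be sketched outside Triton with a toy numpy op. All names here are hypothetical, and `tanh` merely stands in for the expensive KDA intermediates:

```python
import numpy as np

def toy_forward(x, disable_recompute):
    """Forward pass for y = tanh(x)**2; optionally keep the intermediate h."""
    h = np.tanh(x)                    # stand-in for w/u/qg/kg/v_new/h
    y = h * h
    # disable_recompute=True: faster backward, more saved memory
    saved = {"x": x, "h": h} if disable_recompute else {"x": x}
    return y, saved

def toy_backward(dy, saved):
    """Backward pass: reuse the cached intermediate or recompute it."""
    h = saved.get("h")
    if h is None:
        h = np.tanh(saved["x"])       # recompute path (cheaper forward memory)
    dh = 2.0 * h * dy                 # d(h*h)/dh
    return dh * (1.0 - h * h)         # chain rule through tanh
```

Both paths produce identical gradients; the flag only trades forward-pass memory against backward-pass compute, which mirrors the `disable_recompute` wiring the review describes.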
151-197: LGTM! The Triton version conditional properly gates between:

- `TRITON_ABOVE_3_4_0`: uses the fused `chunk_kda_bwd_wy_dqkg_fused` kernel
- Older versions: falls back to separate `chunk_kda_bwd_dqkwg` + `prepare_wy_repr_bwd`

This ensures compatibility across Triton versions while enabling optimizations on newer versions.

288-293: LGTM! Correctly nullifies `g` when it will be recomputed in backward (when `disable_recompute=False` and `use_gate_in_kernel=True`), saving memory while preserving `g_org` for recomputation.
370-377: Duplicated cumsum call in else branch. The `chunk_local_cumsum` call in the `else` branch (lines 371-377) is identical to the call that occurs unconditionally before the `if ctx.use_gate_in_kernel` block (lines 353-359). This means when `use_gate_in_kernel=False`, the cumsum is applied twice.

Looking more carefully at lines 352-377: the `chunk_local_cumsum` at lines 353-359 is inside the `if ctx.use_gate_in_kernel` block. So the else branch is correct; it handles the non-kernel gate path. The indentation may be confusing, but the logic is correct.

fla/ops/kda/chunk_intra.py (9)
9-16: LGTM! Import updates and the `SOLVE_TRIL_DOT_PRECISION` constant properly set up TF32-aware precision control for matrix operations.

54-54: LGTM! Adding the `USE_SAFE_GATE` parameter enables conditional kernel behavior for the safe-gate optimization path.

143-195: LGTM! The nested conditionals for off-diagonal block computation are structurally sound. Each level (`i_tc2 < T`, `i_tc3 < T`) properly guards its computations, and the accumulated values (`b_Aqk20`, `b_Akk20`, etc.) are correctly computed using the gated exponential differences.

244-277: Conditional forward substitution for safe_gate path. When `USE_SAFE_GATE=True`, the forward substitution is skipped here because the `chunk_kda_fwd_kernel_intra_sub_chunk` kernel (lines 636-736) handles it. This is a valid optimization that avoids redundant computation.
464-504: LGTM! The `SAFE_GATE` path for diagonal block handling uses batched matrix operations instead of element-wise loops:

- Loads full diagonal blocks of `dAqk` and `dAkk`
- Applies proper masking (`m_i_diag_qk`, `m_j_diag_qk`) for numerical stability
- Computes gradients via matrix operations

This approach improves GPU utilization compared to the scalar loop in the `else` branch.
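The payoff of replacing the scalar loop with one matrix product rests on both computing the same contraction; a quick numpy sanity check of that equivalence (generic shapes for illustration, not the kernel's actual tiles):

```python
import numpy as np

rng = np.random.default_rng(0)
dA = rng.standard_normal((8, 8))   # per-block gradient, e.g. a dAqk tile
k = rng.standard_normal((8, 4))    # per-block keys

# element-by-element accumulation, as in the scalar-loop branch
acc = np.zeros((8, 4))
for i in range(8):
    for j in range(8):
        acc[i] += dA[i, j] * k[j]

# single batched product, as in the tl.dot-based SAFE_GATE branch
assert np.allclose(acc, dA @ k)
```

On a GPU the batched form maps onto tensor-core `tl.dot` instructions instead of serialized scalar work, which is the utilization win the review points to.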
560-609: LGTM! The `SAFE_GATE` path for k-transpose computation mirrors the structure of the qk-diagonal path:

- Uses `gather` or a direct load based on `USE_GATHER` capability
- Applies proper masking for boundary handling
- Maintains numerical stability through the gating exponential differences
623-736: New sub-chunk kernel for safe_gate forward path. The `chunk_kda_fwd_kernel_intra_sub_chunk` kernel implements the intra-chunk computation with improved numerical stability:

- Uses a reference point `b_gn` at the sub-chunk midpoint for stable exponential computation
- Computes `Aqk` and `Akk` with proper masking
- Includes forward substitution for the matrix inverse

One observation: the forward substitution loop (lines 730-734) modifies `b_Ai` iteratively using scalar loads from `Akk`. Ensure the boundary condition `min(BC, T - i_ti)` correctly handles edge cases where the chunk extends beyond sequence length.
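The midpoint-reference trick can be illustrated in numpy: a decay entry `A[i, j] = 2**(g[i] - g[j])` factors as `2**(g[i] - g_ref) * 2**(g_ref - g[j])`, and anchoring `g_ref` mid-chunk keeps both exponents bounded by about half the chunk's gate range. This is an illustrative sketch of the idea, not the kernel's exact indexing:

```python
import numpy as np

def decay_naive(g):
    # direct form: exponents span the full range of pairwise g differences
    return np.tril(np.exp2(g[:, None] - g[None, :]))

def decay_midpoint(g):
    # factored form: exponents measured against a mid-chunk reference
    g_ref = g[len(g) // 2]
    left = np.exp2(g - g_ref)          # analogous to exp2(b_g - b_gn)
    right = np.exp2(g_ref - g)         # analogous to exp2(-(b_g - b_gn))
    return np.tril(left[:, None] * right[None, :])

g = np.cumsum(np.full(16, -0.5))       # cumulative log2 gates, decreasing
assert np.allclose(decay_naive(g), decay_midpoint(g))
```

Both forms are algebraically identical, but with large cumulative gates the direct form's exponent spans the whole chunk range while the factored form roughly halves the worst-case exponent magnitude, which is the fp32 stability win the kernel exploits.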
768-801: LGTM! The conditional dispatch correctly routes to:

- `chunk_kda_fwd_kernel_intra_sub_chunk` when `safe_gate=True`
- `chunk_kda_fwd_intra_token_parallel` when `safe_gate=False`

This enables the safe-gate optimization path while maintaining backward compatibility.

823-833: LGTM! The return signature is correctly updated to `(w, u, qg, kg, Aqk, Akk)` to include the new `qg` output from `recompute_w_u_fwd`.
Co-authored-by: Wang Yucheng <wangyucheng@moonshot.cn>
Co-authored-by: yzhangcs <zhangyu02@moonshot.cn>
Co-authored-by: Chen Guangyu <chenguangyu@moonshot.cn>
Actionable comments posted: 1
♻️ Duplicate comments (1)
fla/ops/kda/chunk.py (1)

217-218: Fix typo in assertion message (duplicate issue). The assertion message says "sage_gate" but should say "safe_gate". This was flagged in a previous review.
🧹 Nitpick comments (2)
fla/ops/kda/chunk.py (2)

257-258: Use idiomatic boolean comparison. Per PEP 8, prefer `if not disable_recompute and use_gate_in_kernel:` over `if disable_recompute is False and use_gate_in_kernel:`.

🔎 Proposed fix

```diff
-        if disable_recompute is False and use_gate_in_kernel:
+        if not disable_recompute and use_gate_in_kernel:
```

286-286: Use idiomatic boolean comparison. Same as above: prefer `if not ctx.disable_recompute and ctx.use_gate_in_kernel:`.

🔎 Proposed fix

```diff
-        if ctx.disable_recompute is False and ctx.use_gate_in_kernel:
+        if not ctx.disable_recompute and ctx.use_gate_in_kernel:
             g = kda_gate_chunk_cumsum(
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- fla/ops/kda/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_intra.py
- fla/ops/kda/chunk_intra_token_parallel.py
- fla/ops/kda/gate.py
- fla/ops/kda/wy_fast.py
- tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (2)
- fla/ops/kda/gate.py
- fla/ops/kda/chunk_bwd.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.
Applied to files:
fla/ops/kda/chunk_intra.py
🧬 Code graph analysis (3)
tests/ops/test_kda.py (2)

- fla/ops/kda/gate.py (3): naive_kda_lowerbound_gate (50-62), naive_kda_gate (19-47), backward (301-310)
- fla/ops/kda/chunk.py (1): backward (276-348)
fla/ops/kda/chunk_intra.py (4)

- fla/ops/utils/index.py (1): prepare_chunk_indices (112-122)
- fla/ops/utils/op.py (1): gather (25-31)
- fla/ops/kda/chunk_intra_token_parallel.py (2): grid (152-152), chunk_kda_fwd_intra_token_parallel (117-169)
- fla/ops/kda/wy_fast.py (1): recompute_w_u_fwd (212-256)
fla/ops/kda/chunk.py (5)

- fla/ops/kda/gate.py (2): kda_gate_bwd (234-272), kda_gate_chunk_cumsum (416-456)
- fla/ops/kda/wy_fast.py (1): recompute_w_u_fwd (212-256)
- fla/ops/utils/index.py (1): prepare_chunk_indices (112-122)
- fla/ops/kda/chunk_intra.py (1): chunk_kda_fwd_intra (739-833)
- fla/ops/common/chunk_delta_h.py (1): chunk_gated_delta_rule_fwd_h (464-514)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (21)
fla/ops/kda/chunk_intra_token_parallel.py (1)
106-114: LGTM - Refactored masking and indexing logic. The changes correctly:

- Apply the mask post-multiply via `tl.where(m_k[None, :], b_kgj, 0.0)` instead of inline
- Use the multiplicative mask `* tl.where(j < i_t, 1.0, 0.0)` for the `b_Akk` computation
- Change the index from `j % BC` to `j - i_ts`, which is equivalent since `j` ranges from `i_ts` to `i_ts + BC - 1`

These are functionally equivalent transformations that align with the new gating/disable-recompute flow.
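The multiplicative mask in the second bullet is the usual trick of encoding strict causality as a 0/1 factor rather than a branch; in numpy terms (a generic illustration, not the kernel's tile layout):

```python
import numpy as np

T = 8
i = np.arange(T)
# tl.where(j < i_t, 1.0, 0.0) as a dense matrix: 1 strictly below the diagonal
mask = np.where(i[None, :] < i[:, None], 1.0, 0.0)
assert np.allclose(mask, np.tril(np.ones((T, T)), k=-1))

# applying it multiplicatively zeroes the non-causal entries of a score block
scores = np.ones((T, T))
assert np.allclose(scores * mask, np.tril(scores, k=-1))
```

Multiplying by a 0/1 factor keeps the computation branch-free, which is generally friendlier to SIMT execution than per-element conditionals.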
fla/ops/kda/wy_fast.py (3)
22-22: LGTM - Extended DOT_PRECISION autotune choices. Adding `"tf32"` provides a useful middle ground between `"tf32x3"` (higher precision, slower) and `"ieee"` (full precision, slowest), allowing the autotuner to find better performance/precision trade-offs.

28-28: LGTM - Kernel renames for KDA-specific variants. Renaming to `recompute_w_u_fwd_kda_kernel` and `prepare_wy_repr_bwd_kda_kernel` improves clarity by distinguishing these KDA-specific kernels from potential generic variants. Also applies to: 116-116.

221-256: LGTM - Extended return values for disable_recompute support. The function now returns `(w, u, qg, kg)`, where `qg` and `kg` are optionally computed based on whether `q` and `gk` are provided. This aligns with the `disable_recompute` flow, where these intermediates need to be saved during the forward pass to avoid recomputation in backward.

tests/ops/test_kda.py (4)
150-168: LGTM - Extended test parameterization for gating controls. The test matrix now covers combinations of the `safe_gate` and `disable_recompute` flags, providing good coverage of the new code paths.

196-203: LGTM - Conditional gating logic for safe_gate. When `safe_gate=True`, the test correctly:

- Sets `lower_bound = -5.0`
- Clamps `g` to `[-5, 0]` when not using the gate in kernel
- Selects `naive_kda_lowerbound_gate` as the reference function
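Why the `[-5, 0]` clamp matters: with per-step log2 gates bounded below, the cumulative decay across a sub-chunk cannot underflow to exactly zero, so quantities like its reciprocal stay finite. A sketch of the test-side preparation (the clamp mirrors `g.clamp(-5, 0)` from the test; the bound analysis is illustrative):

```python
import numpy as np

def clamp_log_gate(g, lower_bound=-5.0):
    """Per-step log2 forget gates restricted to [lower_bound, 0]."""
    return np.clip(g, lower_bound, 0.0)

rng = np.random.default_rng(0)
g = clamp_log_gate(rng.normal(-2.0, 3.0, size=16))

# cumulative decay across the sub-chunk, in the exp2 domain
cum = np.cumsum(g)
decay = np.exp2(cum[-1] - cum[0])

# bounded below by 2**(lower_bound * (steps - 1)) and above by 1
assert 2.0 ** (-5.0 * 15) <= decay <= 1.0
```

That guaranteed floor is what makes the "safe" gate safe to invert or divide by inside the kernel's sub-chunk computation.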
263-264: Verify tolerance changes for gradient assertions. The tolerances for `dA` (0.003) and `dbias` (0.008) appear tighter than before. Ensure these are appropriate for the precision characteristics of the new gating paths.

355-356: Reference uses naive_kda_gate but tests with use_gate_in_kernel. When `use_gate_in_kernel=True` in varlen tests, the reference path uses `naive_kda_gate`, but the test doesn't account for `safe_gate` in the reference computation. Since varlen tests with `safe_gate=True` require `use_gate_in_kernel=False`, this is currently safe, but the logic could be clearer.

fla/ops/kda/chunk.py (4)
31-46: LGTM - Extended forward signature and data flow. The forward function correctly:

- Accepts the new `safe_gate` and `disable_recompute` parameters
- Passes them through to `chunk_kda_fwd_intra`
- Receives the extended return values `(w, u, qg, kg, Aqk, Akk)`

72-76: LGTM - Conditional memory cleanup. When `disable_recompute=False`, intermediates are set to `None` to free memory since they'll be recomputed in backward. When `disable_recompute=True`, they're preserved for use in backward.

94-121: LGTM - Backward supports both recompute and cached paths. The backward function now:

- Recomputes `w, u, qg, kg, h, v_new` when `disable_recompute=False`
- Uses cached values from `kwargs` when `disable_recompute=True`

This provides flexibility between memory usage (recompute) and speed (cached).
347-348: Verify return tuple length matches forward parameters. The backward returns 17 values (including `None`s). Verify this matches the forward's positional argument count: `q, k, v, g, beta, A_log, dt_bias, scale, initial_state, output_final_state, use_qk_l2norm_in_kernel, use_gate_in_kernel, cu_seqlens, cu_seqlens_cpu, safe_gate, lower_bound, disable_recompute` = 17 parameters. This looks correct.

fla/ops/kda/chunk_intra.py (9)
9-16: LGTM - Updated imports and precision configuration. The imports are correctly updated to include `gather`, `IS_GATHER_SUPPORTED`, and `prepare_chunk_indices`. The `SOLVE_TRIL_DOT_PRECISION` change from `'tf32x3'` to `'tf32'` was noted in a previous review as an intentional precision-performance trade-off.

54-54: LGTM - Added USE_SAFE_GATE parameter. The `USE_SAFE_GATE` constexpr controls whether to skip forward substitution when the sub-chunk kernel already handles the diagonal block inversion.

143-195: LGTM - Restructured off-diagonal block computation. The nested if-blocks for `i_tc2` and `i_tc3` correctly compute off-diagonal Aqk and Akk blocks. The structure change improves code organization by grouping related computations together.

244-277: LGTM - Conditional forward substitution. When `USE_SAFE_GATE=False`, the standard forward substitution is performed to compute the inverse. When `USE_SAFE_GATE=True`, this is skipped because the `chunk_kda_fwd_kernel_intra_sub_chunk` kernel already computes the inverse directly.
464-504: Verify SAFE_GATE gradient computation correctness. The SAFE_GATE branch uses a block-based approach with masks `m_i_diag_qk` and `m_j_diag_qk` instead of the element-by-element loop. The computation:

- Loads diagonal dAqk/dAkk blocks
- Applies masks for valid positions
- Computes `exp2(b_g - b_gn)` and `exp2(-(b_g - b_gn))` for numerical stability
- Accumulates gradients via `tl.dot`

This should be mathematically equivalent to the loop-based approach but more efficient.
623-737: LGTM - New intra sub-chunk kernel for safe gating. The `chunk_kda_fwd_kernel_intra_sub_chunk` kernel:

- Computes Aqk and Akk for diagonal blocks using gating with numerical stability (subtracting `b_gn` from the middle of the block)
- Performs forward substitution inline to compute the inverse
- Uses `gather` when supported, otherwise falls back to offset-based loading

The kernel correctly handles:

- Boundary checks via the `m_c` mask
- Lower-triangular masking via `m_Aqk` and `m_Akk`
- Identity addition for the inverse (`b_Ai += m_I`)
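Forward substitution for a lower-triangular inverse, the step the kernel performs inline, can be sketched in numpy under the assumption that the block has the form `I + L` with `L` strictly lower triangular (the kernel's exact sign conventions may differ):

```python
import numpy as np

def inv_unit_lower(L):
    """Invert (I + L) for strictly lower-triangular L by forward substitution."""
    T = L.shape[0]
    X = np.eye(T)                      # analogous to the b_Ai += m_I identity seed
    for i in range(1, T):
        # row i of the inverse depends only on the already-solved rows above it
        X[i, :i] = -L[i, :i] @ X[:i, :i]
    return X

rng = np.random.default_rng(0)
L = np.tril(rng.standard_normal((8, 8)), k=-1)
X = inv_unit_lower(L)
assert np.allclose((np.eye(8) + L) @ X, np.eye(8), atol=1e-8)
```

Each row needs only rows above it, so the loop runs strictly forward: the same dependency order that lets the kernel solve the triangular system row by row with scalar loads.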
768-801: LGTM - Conditional kernel dispatch based on safe_gate. When `safe_gate=True`, uses the new `chunk_kda_fwd_kernel_intra_sub_chunk` kernel. Otherwise, uses the existing `chunk_kda_fwd_intra_token_parallel` path. This provides flexibility without breaking existing functionality.

823-833: LGTM - Updated return signature. The function now returns `(w, u, qg, kg, Aqk, Akk)`, matching the updated caller expectations. The `q` parameter to `recompute_w_u_fwd` is conditionally passed based on `disable_recompute` to compute `qg` only when needed.

690-696: No action needed. The code correctly branches on `USE_GATHER` before calling `gather()`. When `IS_GATHER_SUPPORTED=False`, `USE_GATHER` is `False`, so the `else` branch executes with `tl.load()` and `gather()` is never called. The fallback `gather()` returning `None` is not executed in practice; it exists only to satisfy the Triton compiler. The logic is correct.

Likely an incorrect or invalid review comment.
```python
if safe_gate:
    assert use_gate_in_kernel is False
    g = g.clamp(-5, 0)
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# First, let's examine the test file around the mentioned lines
echo "=== Examining lines 273-313 in test_kda.py ==="
sed -n '273,313p' tests/ops/test_kda.py
echo -e "\n=== Searching for safe_gate=True in test_kda.py ==="
rg -n "safe_gate.*True" tests/ops/test_kda.py -A3 -B3
```

Repository: fla-org/flash-linear-attention
Length of output: 1814
🏁 Script executed:

```shell
#!/bin/bash
# Search for documentation or comments about safe_gate constraint
echo "=== Searching for comments about safe_gate or use_gate_in_kernel constraint ==="
rg -n "safe_gate|use_gate_in_kernel" tests/ops/test_kda.py -B2 | head -50
echo -e "\n=== Checking for docstring in test_chunk_varlen function ==="
sed -n '281,295p' tests/ops/test_kda.py
echo -e "\n=== Searching for safe_gate in other files ==="
rg -l "safe_gate" --type py | head -10
```

Repository: fla-org/flash-linear-attention
Length of output: 2478
🏁 Script executed:

```shell
#!/bin/bash
# Look at the constraint in the first test function (test_chunk) around line 196-198
echo "=== Context in test_chunk function around safe_gate constraint ==="
sed -n '190,205p' tests/ops/test_kda.py
echo -e "\n=== Checking implementation files for safe_gate constraint documentation ==="
rg -n "safe_gate" fla/ops/kda/chunk.py -B3 -A3 | head -50
echo -e "\n=== Checking chunk_intra.py ==="
rg -n "safe_gate" fla/ops/kda/chunk_intra.py -B2 -A2 | head -30
```

Repository: fla-org/flash-linear-attention
Length of output: 3349
🏁 Script executed:

```shell
#!/bin/bash
# Find test_chunk parameters and check for safe_gate=True cases
echo "=== Extracting test_chunk parameters (lines 140-170) ==="
sed -n '140,170p' tests/ops/test_kda.py
echo -e "\n=== Count test cases with safe_gate=True across both test functions ==="
echo "test_chunk with safe_gate=True:"
sed -n '140,170p' tests/ops/test_kda.py | grep -c "True.*True" || echo "Checking..."
sed -n '140,170p' tests/ops/test_kda.py | tail -20
echo -e "\ntest_chunk_varlen with safe_gate=True:"
sed -n '273,278p' tests/ops/test_kda.py
```

Repository: fla-org/flash-linear-attention
Length of output: 2845
Add a descriptive message to the safe_gate assertion in the varlen test.
The assertion `assert use_gate_in_kernel is False` correctly enforces a varlen-specific constraint (unlike the non-varlen `test_chunk`, which permits this combination), but it lacks an error message explaining why in-kernel gate computation is unsupported with safe gating for variable-length sequences. Update to: `assert use_gate_in_kernel is False, "safe_gate is not supported with in-kernel gate computation for varlen"`.
🤖 Prompt for AI Agents
In tests/ops/test_kda.py around lines 311 to 313, the assertion for safe_gate
lacks a descriptive error message; update the assertion to include a clear
message such as "safe_gate is not supported with in-kernel gate computation for
varlen" so that if the test fails it explains that in-kernel gate computation is
unsupported with safe gating for variable-length sequences.
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Actionable comments posted: 1
♻️ Duplicate comments (2)

tests/ops/test_kda.py (1)

311-313: Add a descriptive message to the safe_gate assertion. The assertion `assert use_gate_in_kernel is False` correctly enforces a constraint, but it lacks an error message explaining why in-kernel gate computation is unsupported with safe gating for variable-length sequences.

🔎 Proposed fix

```diff
     if safe_gate:
-        assert use_gate_in_kernel is False
+        assert use_gate_in_kernel is False, "safe_gate is not supported with in-kernel gate computation for varlen"
         g = g.clamp(-5, 0)
```

fla/ops/kda/chunk.py (1)
219-221: Fix typo in assertion message. "sage_gate" should be "safe_gate".

🔎 Proposed fix

```diff
     if safe_gate:
-        assert lower_bound is not None, "lower_bound must be set when use sage_gate"
+        assert lower_bound is not None, "lower_bound must be set when using safe_gate"
```
🧹 Nitpick comments (1)

fla/ops/kda/chunk.py (1)

259-260: Use idiomatic boolean comparison. Per Python style guidelines (PEP 8), prefer `not disable_recompute` over `disable_recompute is False` for boolean checks.

🔎 Proposed fix

```diff
-        if disable_recompute is False and use_gate_in_kernel:
+        if not disable_recompute and use_gate_in_kernel:
             g = None  # type: ignore
```

```diff
-        if ctx.disable_recompute is False and ctx.use_gate_in_kernel:
+        if not ctx.disable_recompute and ctx.use_gate_in_kernel:
             g = kda_gate_chunk_cumsum(
```

Also applies to: 288-289
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- fla/ops/kda/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_intra.py
- fla/ops/kda/chunk_intra_token_parallel.py
- fla/ops/kda/gate.py
- fla/ops/kda/wy_fast.py
- tests/ops/test_kda.py
✅ Files skipped from review due to trivial changes (1)
- fla/ops/kda/gate.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.
Applied to files:
fla/ops/kda/chunk_intra.py
🧬 Code graph analysis (3)
tests/ops/test_kda.py (2)

- fla/ops/kda/gate.py (3): naive_kda_lowerbound_gate (50-62), naive_kda_gate (19-47), backward (301-310)
- fla/ops/kda/chunk.py (1): backward (278-350)
fla/ops/kda/chunk.py (5)

- fla/ops/kda/gate.py (1): kda_gate_bwd (234-272)
- fla/ops/kda/wy_fast.py (1): recompute_w_u_fwd (212-256)
- fla/ops/utils/cumsum.py (1): chunk_local_cumsum (429-469)
- fla/ops/utils/index.py (1): prepare_chunk_indices (112-122)
- fla/ops/kda/chunk_intra.py (1): chunk_kda_fwd_intra (739-833)
fla/ops/kda/chunk_intra.py (3)

- fla/ops/utils/op.py (1): gather (25-31)
- fla/ops/kda/chunk_intra_token_parallel.py (2): grid (152-152), chunk_kda_fwd_intra_token_parallel (117-169)
- fla/ops/kda/wy_fast.py (1): recompute_w_u_fwd (212-256)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (20)
fla/ops/kda/chunk_bwd.py (1)
220-220: LGTM!Adding this cautionary comment to the
tl.debug_barrier()is a good practice. Synchronization barriers in Triton kernels are critical for correctness, and this annotation helps prevent accidental removal during future refactoring.fla/ops/kda/wy_fast.py (3)
22-22: Addingtf32precision option.Adding
"tf32"to the DOT_PRECISION choices alongside"tf32x3"and"ieee"is a reasonable extension. This provides an additional precision-performance tradeoff option. Based on a past review comment, this may improve performance but could reduce numerical precision compared totf32x3.
28-28: LGTM - Kernel renaming to KDA-specific variants.Renaming
recompute_w_u_fwd_kernel→recompute_w_u_fwd_kda_kernelandprepare_wy_repr_bwd_kernel→prepare_wy_repr_bwd_kda_kernelclearly indicates these are KDA-specific kernel implementations. This improves code organization and prevents confusion with potential non-KDA variants.Also applies to: 116-116
221-256: LGTM - Extended return type for `recompute_w_u_fwd`. The function now returns `(w, u, qg, kg)` instead of just `(w, u)`, properly propagating the gated query and key tensors needed when `disable_recompute=True`. The conditional allocation of `qg` and `kg` based on input presence (lines 233-234) is memory-efficient.

tests/ops/test_kda.py (4)
150-167: LGTM - Test parameterization extended for new features. Good coverage of `safe_gate` and `disable_recompute` combinations across different test configurations. The test IDs properly include the new parameters for clear identification.
196-203: LGTM - Safe gate handling with `lower_bound` and naive gate function selection. Correctly sets `lower_bound = -5.0` when `safe_gate=True` and selects the appropriate naive gate function (`naive_kda_lowerbound_gate` vs `naive_kda_gate`) for reference comparison. The clamping to `[-5, 0]` when not using the in-kernel gate ensures consistency with the safe gate assumptions.
263-264: Note the tolerance adjustments for gradient checks. The tolerances for `dA` (0.003) and `dbias` (0.008) are relatively tight. Consider whether they might need loosening if tests become flaky on different hardware configurations.
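The kind of absolute-tolerance comparison these gradient checks rely on can be sketched as follows (the helper name and sample values are illustrative, not the test suite's actual API):

```python
def assert_close(ref, out, atol, name):
    # compare a flattened gradient against the reference implementation
    # within an absolute tolerance, reporting the worst element
    err = max(abs(r - o) for r, o in zip(ref, out))
    if err > atol:
        raise AssertionError(f"{name}: max abs error {err:.3e} exceeds {atol}")

dA_ref = [0.1000, -0.2000, 0.3000]
dA_out = [0.1004, -0.1998, 0.2995]   # e.g. kernel output in reduced precision
assert_close(dA_ref, dA_out, atol=0.003, name="dA")   # passes: max err 5e-4
```

A tolerance of 0.003 leaves little headroom over typical tf32/fp16 rounding noise, which is why flakiness on other hardware is worth watching for.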
355-356: Verify the reference implementation uses the correct gate function for varlen tests. When `use_gate_in_kernel=True` in varlen tests, the reference uses `naive_kda_gate` rather than `naive_kda_lowerbound_gate` even when `safe_gate` could be True. However, the current parameterization at lines 273-277 shows `safe_gate=False` when `use_gate_in_kernel=True` for varlen, so this is currently safe. If future tests combine `safe_gate=True` with `use_gate_in_kernel=True` in varlen, this would need updating.

fla/ops/kda/chunk_intra.py (6)
9-11: LGTM - Import updates for new functionality. The imports are correctly updated to include `prepare_chunk_indices` from utils, and `gather` with `IS_GATHER_SUPPORTED` for the conditional gather-based data access paths.
14-16: Precision change from `tf32x3` to `tf32`. This change from `'tf32x3'` to `'tf32'` may improve performance but could reduce numerical precision. It aligns with the test tolerance adjustments observed in test_kda.py.
244-277: Safe gate path skips forward substitution on diagonals. When `USE_SAFE_GATE=True`, the forward substitution loop (lines 253-277) is skipped. This relies on the `chunk_kda_fwd_kernel_intra_sub_chunk` kernel having already computed the appropriate diagonal inverse. Ensure this assumption is validated by the test coverage.
464-504: Safe gate backward path uses matrix operations instead of an element-wise loop. The `SAFE_GATE` branch replaces the element-wise loop with vectorized matrix operations using `exp2(b_g - b_gn)`. This should be more efficient on tensor cores. The masking logic (`m_i_diag_qk`, `m_j_diag_qk`) correctly handles boundary conditions.
623-737: New intra-sub-chunk kernel for the safe gate path. The new kernel `chunk_kda_fwd_kernel_intra_sub_chunk` handles the safe gate computation at sub-chunk granularity. Key observations:

- Lines 698-703: uses a pivot point at `BC//2` for numerical stability, keeping gate differences small
- Lines 729-736: performs forward substitution in-place after storing initial values
- The kernel correctly handles boundary conditions with `m_c` masking

The `tl.debug_barrier()` at line 723 ensures stores complete before the forward substitution reads from global memory.
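The in-place forward substitution described here is, in essence, unit-lower-triangular inversion. A pure-Python sketch of the same recurrence (plain lists in place of Triton blocks; names are illustrative):

```python
def invert_unit_lower(T):
    """Invert a unit-lower-triangular matrix by forward substitution."""
    n = len(T)
    # start from the identity, as the kernel does when it adds m_I
    inv = [[float(i == j) for j in range(n)] for i in range(n)]
    for i in range(n):
        for j in range(i):
            # row i depends only on rows k < i, which are already finalized
            inv[i][j] = -sum(T[i][k] * inv[k][j] for k in range(j, i))
    return inv

T = [[1.0,  0.0, 0.0],
     [0.5,  1.0, 0.0],
     [0.25, -2.0, 1.0]]
inv = invert_unit_lower(T)
# T @ inv should be the identity
prod = [[sum(T[i][k] * inv[k][j] for k in range(3)) for j in range(3)]
        for i in range(3)]
assert all(abs(prod[i][j] - (i == j)) < 1e-12 for i in range(3) for j in range(3))
```

Because row `i` only reads rows `k < i`, the kernel can overwrite `Akk` in place, provided the barrier guarantees the initial stores are visible first.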
767-801: Safe gate path selection in the forward wrapper. Good conditional dispatch between the new `chunk_kda_fwd_kernel_intra_sub_chunk` for `safe_gate=True` and the existing `chunk_kda_fwd_intra_token_parallel` for the standard path. The grid dimensions are correctly adjusted for the sub-chunk kernel.

fla/ops/kda/chunk.py (5)
2-2: LGTM - Attribution added. Good practice to acknowledge contributing teams.
72-76: Memory cleanup for the non-recompute path. Setting intermediate tensors to `None` when `disable_recompute=False` is good for memory efficiency. The extended return tuple properly propagates all needed values for the backward pass when `disable_recompute=True`.
94-123: Conditional recomputation in the backward pass. The backward path correctly handles both cases:

- When `disable_recompute=False`: recomputes `w, u, qg, kg, h, v_new`
- When `disable_recompute=True`: retrieves saved values from `kwargs`

This provides a memory-compute tradeoff controlled by the user.
349-350: Verify the backward return count matches the forward input count. The backward returns 17 values. Verify this matches the forward's positional arguments:

1. dq, 2. dk, 3. dv, 4. dg, 5. db, 6. dA, 7. dbias, 8. None (scale), 9. dh0, 10-17. None values for the remaining args.

The forward has: q, k, v, g, beta, A_log, dt_bias, scale, initial_state, output_final_state, use_qk_l2norm_in_kernel, use_gate_in_kernel, cu_seqlens, cu_seqlens_cpu, safe_gate, lower_bound, disable_recompute (17 args). The count matches.
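The alignment can be checked mechanically. A sketch (argument names are taken from the comment above; the check itself is illustrative, not part of the codebase):

```python
# forward's positional arguments, in order
forward_args = [
    "q", "k", "v", "g", "beta", "A_log", "dt_bias", "scale",
    "initial_state", "output_final_state", "use_qk_l2norm_in_kernel",
    "use_gate_in_kernel", "cu_seqlens", "cu_seqlens_cpu",
    "safe_gate", "lower_bound", "disable_recompute",
]
# backward's return values: one slot per forward argument,
# None for non-differentiable inputs (scale, flags, index tensors, ...)
backward_returns = (
    ["dq", "dk", "dv", "dg", "db", "dA", "dbias", None, "dh0"]
    + [None] * 8
)
assert len(forward_args) == 17
assert len(backward_returns) == len(forward_args)
```

Keeping such an assertion next to the autograd Function makes silent misalignment (which PyTorch reports only as a cryptic count mismatch) much easier to catch when arguments are added.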
406-420: LGTM - Well-documented new parameters. The docstring properly documents the new `cu_seqlens_cpu`, `safe_gate`, `lower_bound`, and `disable_recompute` parameters with clear descriptions of their purpose and recommended usage patterns.

fla/ops/kda/chunk_intra_token_parallel.py (1)
106-114: No action required; the index changes are correct. The review comment incorrectly identifies the index change. The code actually uses:

- Line 113: `j % BT` for the Aqk buffer (token-level, size BT)
- Line 114: `j - i_ts` for the Akk buffer (sub-chunk-level, size BC)

These are not interchangeable indices; they serve different purposes for buffers with different logical sizes. Both formulas are mathematically correct: `j % BT` maps the loop variable j to its position within a token, while `j - i_ts` gives the offset within the sub-chunk (where j ranges from `i_ts` to `min(i_t + 1, min(T, i_ts + BC))`). The computational changes (masking approach and multiplicative conditional) are also correct.
```python
if USE_GATHER:
    b_gn = gather(b_g, tl.full([1, BK], min(BC//2, T - i_ti - 1), dtype=tl.int16), axis=0)
else:
    # calculate offset
    p_gn = g + (i_ti + min(BC // 2, T - i_ti - 1)) * H*K + tl.arange(0, BK)
    b_gn = tl.load(p_gn, mask=tl.arange(0, BK) < K, other=0.0)
    b_gn = b_gn[None, :]
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Check how IS_GATHER_SUPPORTED is defined
rg -n "IS_GATHER_SUPPORTED" --type py -B3 -A3 | head -60
```

Repository: fla-org/flash-linear-attention

Length of output: 3259

🏁 Script executed:

```shell
#!/bin/bash
# Find the gather function in utils/op.py
rg -n "def gather" fla/ops/utils/ --type py -B2 -A15
```

Length of output: 1093

🏁 Script executed:

```shell
#!/bin/bash
# Check the full context around lines 690-696 in chunk_intra.py
sed -n '680,705p' fla/ops/kda/chunk_intra.py | cat -n
```

Length of output: 1282

🏁 Script executed:

```shell
#!/bin/bash
# Check for USE_GATHER definition and how it's set
rg -n "USE_GATHER\s*=" fla/ops/kda/chunk_intra.py --type py -B2 -A2
```

Length of output: 321

🏁 Script executed:

```shell
#!/bin/bash
# Search for tests related to chunk_intra
find . -name "*test*" -type f -name "*.py" | xargs grep -l "chunk_intra" 2>/dev/null | head -10
```

Length of output: 56

🏁 Script executed:

```shell
#!/bin/bash
# Check if there are tests that explicitly test USE_GATHER=False or IS_GATHER_SUPPORTED
rg -l "USE_GATHER|IS_GATHER_SUPPORTED" --type py tests/ 2>/dev/null | head -20
```

Length of output: 56

🏁 Script executed:

```shell
#!/bin/bash
# Find all test files in the repository
find . -type f -name "test_*.py" -o -name "*_test.py" | head -20
```

Length of output: 896

🏁 Script executed:

```shell
#!/bin/bash
# Look for tests directory
find . -type d -name "tests" -o -type d -name "test" | head -10
```

Length of output: 81

🏁 Script executed:

```shell
#!/bin/bash
# Check if tests directory exists and what's in it
if [ -d "tests" ]; then ls -la tests/ 2>/dev/null | head -20; else echo "No tests directory found"; fi
```

Length of output: 350

🏁 Script executed:

```shell
#!/bin/bash
# Check the KDA model test file
cat -n tests/models/test_modeling_kda.py | head -100
```

Length of output: 1965

🏁 Script executed:

```shell
#!/bin/bash
# Look for ops tests related to chunk_intra or gather
ls -la tests/ops/ | grep -i kda
```

Length of output: 135

🏁 Script executed:

```shell
#!/bin/bash
# Check what's in ops tests directory
ls -la tests/ops/ | head -30
```

Length of output: 1991

🏁 Script executed:

```shell
#!/bin/bash
# Check test_kda.py for any tests related to gather or USE_GATHER
rg -n "gather|USE_GATHER|IS_GATHER" tests/ops/test_kda.py --type py
```

Length of output: 56

🏁 Script executed:

```shell
#!/bin/bash
# Check the full test_kda.py file to understand test coverage
cat -n tests/ops/test_kda.py | head -150
```

Length of output: 5815

🏁 Script executed:

```shell
#!/bin/bash
# Continue viewing test_kda.py to see all tests
wc -l tests/ops/test_kda.py
```

Length of output: 99

🏁 Script executed:

```shell
#!/bin/bash
# Check if there are any environment variable overrides or mocking for IS_GATHER_SUPPORTED in tests
rg -n "IS_GATHER_SUPPORTED|gather|USE_GATHER" tests/ -A3 -B3 --type py
```

Length of output: 1560

🏁 Script executed:

```shell
#!/bin/bash
# Check the conftest.py file for any fixtures related to gather
cat -n tests/ops/conftest.py
```

Length of output: 7928
Ensure the USE_GATHER=False fallback path is tested.
The scalar load fallback at lines 694-696 is correctly implemented and activates when `IS_GATHER_SUPPORTED=False`. However, the test suite (test_kda.py) does not explicitly exercise this fallback path: tests always use the runtime value of `IS_GATHER_SUPPORTED`. On hardware with gather support, this fallback branch never executes. Consider adding a test that mocks `IS_GATHER_SUPPORTED=False` or parameterizes tests to cover both paths.
🤖 Prompt for AI Agents
In fla/ops/kda/chunk_intra.py around lines 690 to 696, the fallback branch used
when USE_GATHER/IS_GATHER_SUPPORTED is False (the scalar tl.load path) is not
covered by tests; update tests to explicitly exercise this branch by adding or
modifying test_kda.py to run the kernel with IS_GATHER_SUPPORTED
mocked/overridden to False (or parameterize the test to run with both True and
False), verify outputs/behavior match the gather-enabled path, and ensure any
necessary seeding or environment setup is applied so the scalar load path
executes deterministically during the test.
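A toy version of the suggested parameterized test (pure Python; the kernel and `IS_GATHER_SUPPORTED` are replaced by stand-ins since Triton cannot run here, so treat every name as illustrative):

```python
def load_pivot_row(g, pivot, use_gather):
    """Stand-in for the kernel's two data paths: both must agree."""
    if use_gather:
        # "gather" path: index the pivot row in one vectorized step
        return list(g[pivot])
    # fallback path: scalar loads, one element at a time
    return [g[pivot][k] for k in range(len(g[pivot]))]

g = [[0.1 * i + 0.01 * k for k in range(4)] for i in range(8)]

# parameterize over both branches, as the review suggests doing for
# IS_GATHER_SUPPORTED (e.g. via pytest.mark.parametrize plus monkeypatch)
for use_gather in (True, False):
    assert load_pivot_row(g, pivot=3, use_gather=use_gather) == g[3]
```

In the real suite this would likely be a `pytest.mark.parametrize` over a monkeypatched `IS_GATHER_SUPPORTED`, asserting the two kernel variants produce matching outputs.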
Actionable comments posted: 1
🧹 Nitpick comments (1)
fla/ops/kda/chunk.py (1)
483-487: Comprehensive validation for safe_gate mode. The validation correctly ensures that:

- `lower_bound` is specified when `safe_gate=True`
- `lower_bound` is within the safe range `[-5, 0)`

This prevents runtime errors from invalid configurations.
The static analysis tool suggests extracting the long exception messages into constants or custom exception classes to comply with TRY003. However, this is a minor style preference, and the current approach provides clear, contextualized error messages.
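A minimal sketch of the validation being described (the function name and exact messages are illustrative; the real checks live inside the `chunk_kda` entry point):

```python
def validate_safe_gate(safe_gate: bool, lower_bound) -> None:
    # reject configurations the safe-gate kernels cannot handle
    if not safe_gate:
        return
    if lower_bound is None:
        raise ValueError("lower_bound must be provided when safe_gate=True")
    if not (-5.0 <= lower_bound < 0.0):
        raise ValueError("lower_bound must lie in [-5, 0) when safe_gate=True")

validate_safe_gate(True, -5.0)    # ok: boundary of the safe range
validate_safe_gate(False, None)   # ok: no constraint without safe_gate
```

Note the half-open range: `0.0` is rejected, matching the `[-5, 0)` interval in the review.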
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- benchmarks/ops/benchmark_kda.py
- fla/ops/kda/chunk.py
- tests/ops/test_kda.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/ops/test_kda.py
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/ops/benchmark_kda.py (2)
- fla/ops/kda/chunk.py (1): `backward` (278-347)
- fla/ops/kda/gate.py (1): `backward` (301-310)
🪛 Ruff (0.14.10)
fla/ops/kda/chunk.py
485-485: Avoid specifying long messages outside the exception class
(TRY003)
487-487: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (9)
benchmarks/ops/benchmark_kda.py (1)
44-44: LGTM: Direct provider comparison is clearer. The switch from `provider_base == 'xxx'` to `provider == 'xxx'` simplifies the logic and improves readability.

Also applies to: 61-61, 73-73, 92-92, 110-110

fla/ops/kda/chunk.py (8)
fla/ops/kda/chunk.py (8)
2-2: LGTM: Import updates align with the new gating logic. The copyright update and import changes (switching to `kda_gate_chunk_cumsum` and adding `prepare_chunk_indices`) correctly support the new lower-bound-aware chunked gating functionality.

Also applies to: 11-11, 15-15
31-32: LGTM: Memory-efficient recomputation strategy. The conditional memory cleanup at lines 72-76 provides a good trade-off: by default (`disable_recompute=False`), intermediates are freed to save memory and recomputed during the backward pass. When `disable_recompute=True`, intermediates are preserved for faster backward computation at the cost of memory.

Also applies to: 34-47, 72-76
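The trade-off can be sketched with a toy context object standing in for torch.autograd's `ctx` (tensor names follow the PR; everything else is illustrative):

```python
class Ctx:
    """Toy stand-in for an autograd context."""
    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors

def forward_save(ctx, w, u, qg, kg, disable_recompute):
    if not disable_recompute:
        # default: drop the intermediates now, recompute them in backward
        w = u = qg = kg = None
    ctx.save_for_backward(w, u, qg, kg)

ctx = Ctx()
forward_save(ctx, "w", "u", "qg", "kg", disable_recompute=False)
assert ctx.saved_tensors == (None, None, None, None)   # memory freed

forward_save(ctx, "w", "u", "qg", "kg", disable_recompute=True)
assert ctx.saved_tensors == ("w", "u", "qg", "kg")     # kept for speed
```

The backward then branches on the same flag: recompute from the minimal saved set, or read the cached tensors directly.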
94-96: LGTM: Backward pass correctly handles recomputation. The conditional recomputation logic at lines 98-123 properly mirrors the forward pass strategy. When `disable_recompute=False`, intermediates are recomputed; otherwise, they're retrieved from saved kwargs. The use of `**kwargs` at line 96 provides good backward compatibility.

Also applies to: 98-123, 184-184
208-211: LGTM: Forward pass properly implements safe gating with a lower bound. The forward pass correctly:

- Computes chunk indices from `cu_seqlens` when provided
- Validates `safe_gate` requirements (assertion at line 220, already addressed per past comments)
- Applies `kda_gate_chunk_cumsum` with `lower_bound` when safe gating is enabled
- Conditionally cleans up the intermediate `g` to save memory
- Extends saved tensors to include all intermediates needed for the backward pass
Also applies to: 213-238, 244-257, 259-265, 267-267, 269-269, 272-272
283-287: LGTM: Backward pass correctly propagates the new state and flags. The backward pass properly:

- Unpacks the extended saved tensors, including intermediates
- Conditionally recomputes the gate when needed (lines 288-298)
- Threads `safe_gate`, `disable_recompute`, and all intermediates through to `chunk_kda_bwd`
- Correctly applies reverse cumsum in both gating branches
Also applies to: 288-298, 314-316, 324-345
357-366: LGTM: Well-documented API extensions. The public API signature is properly extended with type hints, and the docstring comprehensively documents the new parameters (`cu_seqlens_cpu`, `safe_gate`, `lower_bound`, `disable_recompute`), including their interactions and constraints.

Also applies to: 403-417
495-513: LGTM: `ChunkKDAFunction.apply` properly wired. All new parameters (`cu_seqlens_cpu`, `safe_gate`, `lower_bound`, `disable_recompute`) are correctly passed to the autograd function in the proper order.
490-490: Assertion correctly enforces the kernel's head dimension limit. The assertion `assert k.shape[-1] <= 256` is properly aligned with the underlying kernel constraints. The kernel `chunk_gated_delta_rule_fwd_kernel_h_blockdim64` in fla/ops/common/chunk_delta_h.py enforces the same limit (lines 488, 536), with internal structure supporting up to 256 dimensions via four 64-element blocks.
Manually cast `g` to fp32 when loading it, decoupling the kernels from the requirement of an fp32 `g`.
Actionable comments posted: 0
♻️ Duplicate comments (1)
fla/ops/kda/chunk_intra.py (1)
690-696: Ensure the `USE_GATHER=False` fallback path is tested. The scalar load fallback at lines 694-696 is correctly implemented and activates when `IS_GATHER_SUPPORTED=False`. As noted in a past review, the test suite doesn't explicitly exercise this fallback. On hardware with gather support, this branch never executes. Consider adding a test that parameterizes `IS_GATHER_SUPPORTED` to cover both paths.
🧹 Nitpick comments (3)
fla/ops/kda/chunk_bwd.py (1)
220-220: Barrier retention is important for correctness. The explicit "DO NOT REMOVE" comment on `tl.debug_barrier()` suggests it is required for correct synchronization due to Triton compiler sensitivities. Consider adding a brief comment explaining why removal breaks correctness (e.g., a race condition on shared state or a compiler optimization issue).

fla/ops/common/chunk_delta_h.py (2)
145-162: Consider explicitly casting `b_g_last` to FP32 for consistency. Line 149 explicitly casts `b_g` to `tl.float32`, but line 147 loads `b_g_last` without an explicit cast. The subsequent subtraction `b_g_last - b_g` (lines 151, 154) will trigger implicit casting, but for consistency and clarity, consider explicitly casting `b_g_last`: `b_g_last = tl.load(g + bos * H + last_idx * H + i_h).to(tl.float32)`. This ensures all gate values uniformly use FP32 for numerical operations and makes the dtype handling more explicit.
335-344: Consider explicitly casting `bg_last` to FP32 for consistency. Similar to the forward kernel, line 338 explicitly casts `b_g` to `tl.float32`, but line 336 loads `bg_last` without an explicit cast. For consistency with the explicit FP32 casting pattern, and to avoid implicit type conversions in `bg_last - b_g` (line 387/389), consider: `bg_last = tl.load(g + (bos + last_idx) * H + i_h).to(tl.float32)`
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- fla/ops/common/chunk_delta_h.py
- fla/ops/gla/chunk.py
- fla/ops/kda/chunk_bwd.py
- fla/ops/kda/chunk_intra.py
- fla/ops/kda/wy_fast.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-07-24T11:34:30.618Z
Learnt from: KevlarKanou
Repo: fla-org/flash-linear-attention PR: 544
File: fla/utils.py:0-0
Timestamp: 2025-07-24T11:34:30.618Z
Learning: In `fla/utils.py`, the `TRITON_F32_DEFAULT` environment variable should be forcibly overridden to `'ieee'` for NVIDIA GPUs that don't support tf32 (device capability < 8), rather than using `setdefault()`. This override is required for compatibility with Turing GPUs and should not respect existing user settings.
Applied to files:
fla/ops/kda/chunk_intra.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (17)
fla/ops/gla/chunk.py (1)
352-352: LGTM! Explicit float32 casting ensures numerical precision. Casting `b_g` to `tl.float32` before exponential operations ensures consistent precision regardless of input dtype, aligning with the broader PR objective of supporting FP16 gates while preserving numerical behavior.

fla/ops/kda/chunk_bwd.py (1)
191-194: LGTM! Explicit float32 casting for gate values. Both `b_g` and `b_gn` are now explicitly cast to `tl.float32`, ensuring numerical stability for subsequent `exp2` operations regardless of the input tensor's dtype.
22-22: Verify the `tf32` precision option aligns with numerical requirements. The autotuning now includes `"tf32"` in addition to `"tf32x3"` and `"ieee"`. Note that `tf32` provides lower precision than `tf32x3` (which uses three rounds of TF32). Given that this PR aims to maintain numerical stability, confirm that the autotuner selecting `tf32` over `tf32x3` won't degrade precision below acceptable thresholds for the use cases.
28-28: Kernel renamed to clarify KDA-specific usage. The rename from `recompute_w_u_fwd_kernel` to `recompute_w_u_fwd_kda_kernel` improves clarity. Ensure any external consumers (if any) are updated accordingly.
82-82: LGTM! Float32 casting for gate computations. Explicit `.to(tl.float32)` casts on `b_gk` and `b_gn` ensure numerical stability in exponential operations, consistent with the PR's objective of supporting FP16 gates.

Also applies to: 94-94
221-221: Extended return type enables the recompute optimization. The function now returns a 4-tuple `(w, u, qg, kg)` where `qg` and `kg` are conditionally computed based on `disable_recompute`. This allows the caller to cache precomputed gated values and avoid redundant computation in the backward pass.

Also applies to: 233-234, 256-256
fla/ops/kda/chunk_intra.py (8)
14-16: Precision change from `tf32x3` to `tf32`. Changing `SOLVE_TRIL_DOT_PRECISION` from `'tf32x3'` to `'tf32'` may improve performance but can reduce numerical precision. This aligns with a past review comment noting the tradeoff. Given the PR's test tolerance adjustments, it appears intentional.
244-277: Forward substitution correctly bypassed for the safe-gate path. When `USE_SAFE_GATE` is true, the iterative forward substitution loop is skipped. This is correct because the new `chunk_kda_fwd_kernel_intra_sub_chunk` kernel (lines 636-737) handles diagonal block inversion internally, so `Akkd` already contains the inverted blocks when the safe-gate path is used.
464-504: Safe-gate backward path uses block-wise computation for better Tensor Core utilization. The `SAFE_GATE` branch replaces per-element loops with block-wise matrix operations (`tl.dot`), which better utilize Tensor Cores. The median-shift approach (`BC//2`) for numerical stability is consistent with the forward path.
560-609: Symmetric safe-gate handling in the dk backward computation. The second `SAFE_GATE` block mirrors the first, computing `b_dkt` using block-wise operations instead of element-wise loops. The consistent use of `min(BC//2, T - i_ti - 1)` for the median index ensures numerical stability.
729-736: Forward substitution in the sub-chunk kernel produces inverted diagonal blocks. The loop correctly computes the inverse of the lower-triangular diagonal block in-place. Adding the identity matrix (`m_I`) at the end and overwriting `Akk` ensures the output is the full inverse. This enables the inter-solve kernel to skip its forward substitution when `USE_SAFE_GATE=True`.
767-801: Conditional kernel dispatch based on `safe_gate`. The safe-gate path uses the new `chunk_kda_fwd_kernel_intra_sub_chunk` kernel, which computes diagonal blocks with improved numerical stability via the median shift, while the non-safe-gate path continues using the token-parallel kernel. This branching correctly isolates the new behavior.
823-833: Extended return signature propagates precomputed gated values. The function now returns `(w, u, qg, kg, Aqk, Akk)` where `qg` and `kg` are conditionally computed when `disable_recompute=True`. This allows the forward pass to cache these values for the backward pass, avoiding redundant computation.
892-898: Safe-gate and gather-support flags correctly propagated to the backward kernel. Both `SAFE_GATE=safe_gate` and `USE_GATHER=IS_GATHER_SUPPORTED` are passed to `chunk_kda_bwd_kernel_intra`, ensuring the backward path uses logic matching the forward path.

fla/ops/common/chunk_delta_h.py (3)
164-192: LGTM! Consistent FP32 casting for gk loads. All `gk` tensor loads across the four 64-element key blocks are explicitly cast to `tl.float32`. This ensures numerical stability in subsequent `exp`/`exp2` operations and enables FP16 gate support as intended.
355-382: LGTM! Backward kernel gk casting is consistent. All `gk` loads in the backward kernel are explicitly cast to `tl.float32`, mirroring the forward kernel pattern. This maintains numerical consistency and enables FP16 gate support throughout the forward and backward passes.
149-149: Excellent approach to enable FP16 gate support. The explicit `.to(tl.float32)` casts after loading gate tensors effectively decouple the input tensor dtype from the computation dtype. This allows FP16 gates to be used for memory efficiency while maintaining FP32 precision for the critical exponential operations, preserving numerical stability. The pattern is consistently applied across both forward and backward kernels for all `gk` loads.

Also applies to: 166-166, 173-173, 180-180, 187-187, 338-338, 357-357, 365-365, 373-373, 381-381
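Why the cast matters can be seen with a pure-Python emulation of half precision via `struct`'s `'e'` format (a sketch of the numerical effect; the kernels do the equivalent with `tl.load(...).to(tl.float32)`):

```python
import struct

def fp16(x: float) -> float:
    # round-trip a value through IEEE-754 half precision
    return struct.unpack('<e', struct.pack('<e', x))[0]

# storing a single gate value in fp16 is fine: relative error is tiny
g = -0.0123
assert abs(fp16(g) - g) < 1e-5

# but arithmetic *kept* in fp16 loses small increments entirely once the
# accumulator grows: at magnitude 6 the fp16 ulp is 2**-8 ≈ 0.0039
assert fp16(6.0 + 0.001) == 6.0        # the 0.001 step vanishes
assert 6.0 + 0.001 != 6.0              # in higher precision it survives
```

So fp16 is an acceptable storage format for `g`, provided every load is promoted to fp32 before the cumulative sums and exponentials, which is exactly what the consistent `.to(tl.float32)` pattern guarantees.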
Force-pushed from 8d1659f to 990535b (compare)
Good optimization. Here is a question, in
After profiling, we found that the original chunk_kda kernel is heavily bound by CUDA-core computation. To maintain numerical stability along the diagonal of the (causal-mask) attention matrix, we had to split each 64-element chunk a second time: a 16-element tile ran on Tensor Cores while the remainder was looped over with CUDA cores. This introduced enormous overhead: in Triton such loops are hard to optimize and almost impossible to overlap, and, given the hardware's fixed FLOP/byte ratio, they are inevitably slower.
By switching the activation function we enabled pure Tensor-Core execution with M = 16, which delivered a large end-to-end speed-up. To regain extra precision, and to keep the option of M = 32 Tensor Cores, we added a median shift to the log-domain cumulative sum, reducing the accuracy loss at M = 16 while still allowing M = 32. Further Triton tweaks (e.g., transposing with M = 64 Tensor Cores) have so far yielded only marginal gains. At this point chunk_kda is a purely memory-bound kernel.
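The median shift can be illustrated in pure Python (an illustration of the idea, not the kernel's code; the `BC//2` pivot in the kernel plays the role of `pivot` here):

```python
import itertools
import math

def decay_entry(s_i, s_j, pivot=0.0):
    # A[i][j] = 2**(s_i - s_j), factored into two exp2 terms the way a
    # tiled kernel computes them; the pivot cancels exactly in the product
    return (2.0 ** (s_i - pivot)) * (2.0 ** (pivot - s_j))

g = [-20.0] * 64                         # strongly decaying log-gates
s = list(itertools.accumulate(g))        # log-domain cumulative sum
pivot = s[len(s) // 2]                   # median shift, like the BC//2 pivot

i, j = 60, 59                            # a near-diagonal entry
exact = 2.0 ** (s[i] - s[j])             # = 2**-20, the true value

# with the median shift, both factors stay in a representable range
assert math.isclose(decay_entry(s[i], s[j], pivot=pivot), exact)

# without it, one factor overflows float64 outright
try:
    decay_entry(s[i], s[j], pivot=0.0)
    raised = False
except OverflowError:
    raised = True
assert raised
```

In the kernel the same effect keeps the exponent differences small enough that reduced-precision Tensor-Core dot products (tf32, M = 16 or 32) remain accurate near the diagonal.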
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Chores