[KDA] Fuse dAqk and dv#689

Merged
yzhangcs merged 2 commits into main from kda-bwd
Dec 18, 2025

Conversation

@yzhangcs
Member

@yzhangcs yzhangcs commented Dec 18, 2025

Summary by CodeRabbit

  • Performance Improvements

    • Hardware-aware autotuning and shared‑memory checks enable faster backward passes on modern GPUs.
    • Improved tiling and masking for better variable‑length sequence performance.
  • Refactor

    • Reworked backward computation flow to produce intermediate gradients more efficiently and simplify downstream processing.
  • New Features

    • Added a dedicated backward path that returns both matrix and vector gradient components for more consistent gradient propagation.


@coderabbitai
Contributor

coderabbitai bot commented Dec 18, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Backward KDA computation is refactored to compute dA and dv together via a new path. A Triton kernel and Python wrapper chunk_kda_bwd_dAv were added (Hopper-aware autotuning, var-length support), the previous in-kernel dA reconstruction and dv path were removed, and the legacy kernel was renamed for downstream use.
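The fused path can be illustrated with a plain PyTorch reference (an illustrative sketch only: shapes and the function name are hypothetical, and the actual Triton kernel tiles over the K/V dimensions and handles var-length batches):

```python
import torch

def dAv_reference(do, v, A, scale):
    """Chunk-local reference for the fused backward step (illustrative only).

    do: [B, H, T, V] upstream output gradient
    v:  [B, H, T, V] values (v_new after the delta-rule forward)
    A:  [B, H, T, T] causal (lower-triangular) intra-chunk matrix
    """
    # dAqk = do @ v^T, restricted to the causal part and scaled
    dA = torch.tril(do @ v.transpose(-1, -2)) * scale
    # dv = A^T @ do (the transpose of a causal A is upper-triangular)
    dv = A.transpose(-1, -2) @ do
    return dA, dv
```

Computing both quantities in one pass lets the kernel reuse the loaded `do` tiles for both products, which is where the launch-overhead and locality savings come from.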

Changes

Cohort / File(s) Summary
Backward flow updates
fla/ops/kda/chunk.py
Rewired backward flow to call chunk_kda_bwd_dAv which returns (dAqk, dv); removed imports of legacy helpers (chunk_bwd_dv_local, chunk_gla_bwd_dA, chunk_gla_fwd_o_gk, chunk_inter); updated inline comments and dataflow to pass v_new downstream.
New kernel + wrapper and renames
fla/ops/kda/chunk_bwd.py
Added Triton kernel chunk_bwd_kernel_dAv and Python launcher chunk_kda_bwd_dAv (returning dA and dv jointly) with NVIDIA Hopper-aware autotuning, var-length sequence handling, tiling constants (BK/BV lists, NUM_WARPS), and shared-memory checks; renamed chunk_kda_bwd_kernel_inter to chunk_kda_bwd_kernel_dqkwg; removed the old in-kernel dA/dv paths and updated public imports/signatures.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Verify numerical correctness of the new dA/dv joint path and equivalence to prior separate paths.
  • Inspect Triton kernel pointer arithmetic, masking, var-length handling, and tiling constants (BK_LIST, BV_LIST, NUM_WARPS).
  • Confirm consistency where v_new and returned (dAqk, dv) are consumed by downstream chunk_kda_bwd_dqkwg.

Poem

🐰 I tunneled through kernels, found one tidy track,
dA and dv hopping now, bundled in my pack.
Hopper gears hum, autotuned and bright,
Chunked paths aligned in the soft moonlight.
A rabbit's small cheer for gradients done right. 🎋

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): docstring coverage is 0.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): the title '[KDA] Fuse dAqk and dv' directly describes the main change: fusing the computation of dAqk and dv into a single backward path via the new chunk_kda_bwd_dAv function.


@gemini-code-assist
Contributor

Summary of Changes

Hello @yzhangcs, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the backward pass for the Kimi Delta Attention (KDA) mechanism. The primary goal is to optimize performance by fusing the gradient computations for dAqk and dv into a single operation. This involves introducing a new Triton kernel and a corresponding Python wrapper, streamlining the KDA chunking operations, and improving code organization.

Highlights

  • Fused Gradient Computation: The backward pass for Kimi Delta Attention (KDA) now fuses the computation of dAqk and dv into a single, more efficient operation, replacing two separate calls.
  • New Triton Kernel: A new Triton kernel, chunk_bwd_kernel_dAv, has been introduced to handle the combined calculation of dAqk and dv.
  • Code Refactoring and Renaming: The file fla/ops/kda/chunk_inter.py has been renamed to fla/ops/kda/chunk_bwd.py for better clarity. Additionally, an existing Triton kernel within this module was renamed from chunk_kda_bwd_kernel_inter to chunk_kda_bwd_kernel_dqkwg.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization by fusing the computation of dAqk and dv in the backward pass of KDA into a single Triton kernel. This is a solid improvement that should reduce kernel launch overhead and improve data locality. The implementation includes a new fused kernel, chunk_bwd_kernel_dAv, and its corresponding wrapper function. Additionally, chunk_inter.py has been aptly renamed to chunk_bwd.py for better code organization, which is a nice touch. The overall changes are well-structured and logical. I have one minor suggestion to improve the type hinting for better code correctness and maintainability.

Comment thread fla/ops/kda/chunk_bwd.py Outdated
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
fla/ops/kda/chunk_bwd.py (2)

44-46: Remove unused BK parameter or document its purpose.

The BK parameter is declared as tl.constexpr but never used within the kernel body. This parameter is passed from the wrapper but has no effect on the computation.

🔎 Suggested fix

If BK is not needed for this kernel (since it iterates over V dimension only), remove it:

     BT: tl.constexpr,
-    BK: tl.constexpr,
     BV: tl.constexpr,
     IS_VARLEN: tl.constexpr,

And update the kernel call in chunk_kda_bwd_dAv to not pass BK.


224-225: Use explicit | None type annotation.

Per PEP 484 and static analysis, scale: float = None should use explicit optional typing.

🔎 Suggested fix
-    scale: float = None,
+    scale: float | None = None,
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f5736b3 and 8406bd2.

📒 Files selected for processing (2)
  • fla/ops/kda/chunk.py (2 hunks)
  • fla/ops/kda/chunk_bwd.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
fla/ops/kda/chunk_bwd.py (2)
fla/utils.py (1)
  • check_shared_mem (447-453)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (118-123)
fla/ops/kda/chunk.py (2)
fla/ops/gla/chunk.py (1)
  • chunk_gla_fwd_o_gk (857-895)
fla/ops/kda/chunk_bwd.py (2)
  • chunk_kda_bwd_dAv (219-267)
  • chunk_kda_bwd_dqkwg (270-319)
🪛 Ruff (0.14.8)
fla/ops/kda/chunk_bwd.py

45-45: Unused function argument: BK

(ARG001)


225-225: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (6)
fla/ops/kda/chunk.py (2)

7-8: LGTM!

The imports are correctly updated to reflect the new fused backward path. The removal of chunk_bwd_dv_local, chunk_gla_bwd_dA, and chunk_inter aligns with the consolidated chunk_kda_bwd_dAv approach.


105-117: LGTM!

The fused backward path is well-structured. The inline comments clarify the mathematical operations (dAqk = do @ v.T, dv = A @ do), and the use of v_new (the transformed value after the delta rule forward) is consistent with the forward computation flow.

fla/ops/kda/chunk_bwd.py (4)

9-13: LGTM!

The hardware-specific autotuning configuration is well-organized. The conditional NUM_WARPS based on IS_NVIDIA_HOPPER and shared memory checks for BK_LIST/BV_LIST allow for architecture-specific optimization.
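The config-generation pattern described above can be sketched in plain Python (the names NUM_WARPS, BK_LIST, and BV_LIST mirror the PR; the values and the Hopper flag are illustrative, and the real code builds triton.Config objects rather than dicts):

```python
# Hypothetical sketch of architecture-aware autotune config generation.
IS_NVIDIA_HOPPER = False  # would come from a device query in the real code

# Hopper tolerates fewer, larger-warp configs; older parts get a wider sweep
NUM_WARPS = [4, 8] if IS_NVIDIA_HOPPER else [2, 4, 8]
BK_LIST = [32, 64]  # candidate K-dimension tile sizes
BV_LIST = [32, 64]  # candidate V-dimension tile sizes

configs = [
    {"BK": bk, "BV": bv, "num_warps": w}
    for bk in BK_LIST
    for bv in BV_LIST
    for w in NUM_WARPS
]
print(len(configs))  # 12
```

Pruning the candidate list per architecture keeps autotuning time bounded while still covering the tile shapes that fit each GPU's shared memory.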


69-91: LGTM!

The kernel logic is correct:

  • b_dA accumulates do @ v.T across V-dimension slices (lines 80-84)
  • b_dv computes A.T @ do per V-slice; the upper-triangular mask on b_A (lines 71-72) effectively transposes the causal mask
  • The final dA store correctly applies lower-triangular masking with scaling (line 90)
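The mask-transposition claim in the second bullet rests on the identity tril(A)ᵀ = triu(Aᵀ), which can be checked directly in PyTorch (a standalone sanity check, not the kernel code):

```python
import torch

T = 8
A = torch.randn(T, T)
do = torch.randn(T, 4)

# Reference: causally mask A first, then transpose and multiply
ref = torch.tril(A).transpose(0, 1) @ do
# Kernel-style: work with A^T and keep only its upper-triangular part
out = torch.triu(A.transpose(0, 1)) @ do

assert torch.allclose(ref, out)
```

This is why the kernel can apply an upper-triangular mask in place instead of materializing a transposed, re-masked copy of A.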

109-109: LGTM!

The kernel rename to chunk_kda_bwd_kernel_dqkwg is consistent with the wrapper function name and clearly indicates its purpose (computing gradients for q, k, w, g).


297-297: LGTM!

The kernel call correctly references the renamed kernel chunk_kda_bwd_kernel_dqkwg.

Comment thread fla/ops/kda/chunk_bwd.py
Comment on lines +235 to +240
if check_shared_mem('hopper', k.device.index):
    CONST_TILING = 128
elif check_shared_mem:
    CONST_TILING = 64
else:
    CONST_TILING = 32
Contributor


⚠️ Potential issue | 🟠 Major

Bug: Missing function call parentheses makes the else branch unreachable.

Line 237 checks elif check_shared_mem: which evaluates the function object itself (always truthy), not the function's return value. This makes CONST_TILING = 32 unreachable.

🔎 Proposed fix
     if check_shared_mem('hopper', k.device.index):
         CONST_TILING = 128
-    elif check_shared_mem:
+    elif check_shared_mem('ampere', k.device.index):
         CONST_TILING = 64
     else:
         CONST_TILING = 32

Or if you want to check for any shared memory capability without architecture specification:

-    elif check_shared_mem:
+    elif check_shared_mem():
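The failure mode is easy to reproduce in isolation: a bare function reference is always truthy, so the branch can never fall through (a standalone illustration with a stand-in function, not the repository's check_shared_mem):

```python
def check_shared_mem(arch=None, device=None):
    # Stand-in for the real helper; pretend no capability is detected
    return False

if check_shared_mem:        # BUG: tests the function object, always True
    tiling_buggy = 64
else:
    tiling_buggy = 32       # unreachable

if check_shared_mem():      # correct: tests the return value
    tiling_fixed = 64
else:
    tiling_fixed = 32

print(tiling_buggy, tiling_fixed)  # 64 32
```

Linters catch this class of bug as well (e.g. Ruff's "truthy function" checks), which is a useful backstop for conditionals built from capability-probe helpers.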

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
fla/ops/kda/chunk_bwd.py (1)

235-240: Bug: Missing function call on line 237.

Line 237 evaluates check_shared_mem as a function object (always truthy) rather than calling it, making the else branch unreachable. This is the same issue flagged in the previous review.

🔎 Proposed fix
     if check_shared_mem('hopper', k.device.index):
         CONST_TILING = 128
-    elif check_shared_mem:
+    elif check_shared_mem('ampere', k.device.index):
         CONST_TILING = 64
     else:
         CONST_TILING = 32
🧹 Nitpick comments (2)
fla/ops/kda/chunk_bwd.py (2)

45-45: Consider removing unused parameter BK.

The BK parameter is declared but never used in the kernel body. While this doesn't affect functionality, removing it would clean up the signature.


225-225: Type hint style: Use explicit union with None.

Consider using float | None instead of implicit Optional for consistency with PEP 484.

🔎 Proposed fix
-    scale: float = None,
+    scale: float | None = None,
📜 Review details

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8406bd2 and 34c5220.

📒 Files selected for processing (1)
  • fla/ops/kda/chunk_bwd.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/kda/chunk_bwd.py (2)
fla/utils.py (1)
  • check_shared_mem (447-453)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (118-123)
🪛 Ruff (0.14.8)
fla/ops/kda/chunk_bwd.py

45-45: Unused function argument: BK

(ARG001)


225-225: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
