
[Gated DeltaNet] Refactor the kernel to remove one matrix inversion#433

Merged
sustcsonglin merged 3 commits into main from gated_deltanet
Jun 9, 2025

Conversation

@sustcsonglin
Collaborator

@sustcsonglin sustcsonglin commented Jun 8, 2025

Summary by CodeRabbit

  • Refactor

    • Simplified and unified the handling of gating and matrix operations in several attention-related functions, reducing intermediate tensors and streamlining tensor processing.
    • Removed preprocessing steps and redundant outputs in forward and backward passes, resulting in more efficient computation and memory usage.
    • Updated function signatures and internal logic to reflect these simplifications, with no changes required to user-facing APIs.
  • Bug Fixes

    • Improved consistency in tensor handling and gating logic across attention mechanisms, reducing the risk of mismatches or errors.

@coderabbitai
Contributor

coderabbitai bot commented Jun 8, 2025

Walkthrough

This update removes the preprocess_qkw Triton kernel and its usage, integrating gate-based scaling directly into the main forward and backward kernels for chunked delta rules. Related kernels and host functions are refactored to handle gating internally. Return values and signatures are updated throughout to reflect these changes, simplifying the control flow.
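To make the change concrete, the gate factor folded into the kernels can be sketched in plain NumPy. This is an illustrative model only: `safe_exp`, `gate_chunk`, and the chunk layout below are assumptions inferred from the summary, not the actual Triton code.

```python
import numpy as np

def safe_exp(x):
    # exp clamped at 0 so the factor never exceeds 1; g is a cumulative
    # sum of log-decay terms, so g_last - g should already be <= 0
    return np.exp(np.minimum(x, 0.0))

def gate_chunk(k, g):
    # Scale every row of a [T, K] chunk by exp(g_last - g_t), the decay
    # from position t to the end of the chunk. This is the scaling that
    # the removed preprocess_qkw pass used to apply up front and that the
    # main kernels now apply internally.
    g_last = g[-1]
    return k * safe_exp(g_last - g)[:, None]

rng = np.random.default_rng(0)
T, K = 4, 8
k = rng.standard_normal((T, K))
g = np.cumsum(np.log(rng.uniform(0.5, 1.0, size=T)))  # negative log-gates
out = gate_chunk(k, g)
```

The last row is untouched (its decay factor is exp(0) = 1), and every other row is shrunk toward zero, which is why no separate preprocessed copy of the tensors is needed.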

Changes

File(s) and change summary:

  • fla/ops/common/chunk_delta_h.py: Removed preprocess_qkw kernel and its calls; integrated gating into main forward and backward kernels; renamed parameters; simplified autotuning.
  • fla/ops/common/chunk_scaled_dot_kkt.py: Updated kernel and host function to return only A, removed Ag; gating now modifies A in-place.
  • fla/ops/delta_rule/wy_fast.py: Adjusted to expect a single return value from chunk_scaled_dot_kkt_fwd.
  • fla/ops/gated_delta_rule/chunk.py: Unified handling of the WY representation with a single A tensor instead of Aw, Au; updated signatures and saved tensors.
  • fla/ops/gated_delta_rule/wy_fast.py: Unified Aw and Au into A; added g argument; updated kernel logic for gating and recomputation; streamlined gradient computations.
  • fla/ops/path_attn/parallel.py: Adjusted to expect a single return value from chunk_scaled_dot_kkt_fwd.
  • fla/ops/common/chunk_o.py: Modified the gradient write-back condition to always write dw regardless of gating; removed the gate-related gradient update block.

Sequence Diagram(s)

sequenceDiagram
    participant Host
    participant MainFwdKernel
    participant GateTensor

    Host->>MainFwdKernel: Launch forward kernel with k, w, q, g
    MainFwdKernel->>GateTensor: Load gate values (if USE_G)
    MainFwdKernel->>MainFwdKernel: Apply gating via safe_exp(g_last - g)
    MainFwdKernel->>Host: Return output tensors
sequenceDiagram
    participant Host
    participant ScaledDotKKTKernel

    Host->>ScaledDotKKTKernel: Launch kernel with k, beta, g_cumsum
    ScaledDotKKTKernel->>ScaledDotKKTKernel: Modify A in-place if USE_G
    ScaledDotKKTKernel->>Host: Return A
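As a rough model of the second diagram, here is a NumPy sketch of the in-place gating in the KKᵀ kernel. The function name, the masking, and the gate sign convention are assumptions inferred from the change summary, not the library's API.

```python
import numpy as np

def chunk_scaled_dot_kkt(k, beta, g_cumsum=None):
    # A = (beta * k) @ k^T: each row i scores key i against every other key
    A = (beta[:, None] * k) @ k.T
    if g_cumsum is not None:
        # in-place gating: entry (i, j) decays by exp(g_i - g_j); clamping
        # at 0 mirrors a safe_exp-style guard against overflow
        A *= np.exp(np.minimum(g_cumsum[:, None] - g_cumsum[None, :], 0.0))
    # keep the strictly lower triangle so only past positions contribute
    return np.tril(A, -1)

rng = np.random.default_rng(0)
T, K = 8, 4
keys = rng.standard_normal((T, K))
beta = rng.uniform(0.1, 1.0, T)
g = np.cumsum(np.log(rng.uniform(0.5, 1.0, T)))
A_plain = chunk_scaled_dot_kkt(keys, beta)
A_gated = chunk_scaled_dot_kkt(keys, beta, g)
```

Because the gate is applied in-place to A rather than producing a second tensor Ag, the host function has a single return value, which is why the downstream callers had to be updated.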


Suggested reviewers

  • yzhangcs

Poem

In the warren of code, a gate did appear,
Preprocessing vanished, the logic is clear.
Kernels now handle the gating with grace,
Fewer tensors to juggle, a streamlined embrace.
With a hop and a skip, the rabbits all cheer—
For simpler kernels are finally here!
🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a343221 and af35cea.

📒 Files selected for processing (4)
  • fla/ops/common/chunk_delta_h.py (13 hunks)
  • fla/ops/common/chunk_o.py (1 hunks)
  • fla/ops/gated_delta_rule/chunk.py (10 hunks)
  • fla/ops/gated_delta_rule/wy_fast.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/ops/gated_delta_rule/chunk.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/ops/common/chunk_delta_h.py (1)
fla/ops/utils/op.py (1)
  • safe_exp (28-29)
🪛 Pylint (3.3.7)
fla/ops/gated_delta_rule/wy_fast.py

[refactor] 192-192: Too many arguments (6/5)

(R0913)


[refactor] 192-192: Too many positional arguments (6/5)

(R0917)


[refactor] 192-192: Too many local variables (18/15)

(R0914)

⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (9)
fla/ops/common/chunk_o.py (1)

218-220: Verify the removal of gating condition for dw gradient storage.

The condition for storing dw gradient has been simplified from if USE_DW and not USE_G: to if USE_DW:. This means b_dw is now written back to dw regardless of the gating flag USE_G.

Please confirm this aligns with the broader refactoring to unify gating computations and that any necessary gating adjustments for dw are handled appropriately elsewhere in the codebase.

fla/ops/common/chunk_delta_h.py (4)

38-84: Parameter renaming is consistent with the refactoring.

The renaming from d to w throughout the forward kernel and function signatures is consistent with the naming conventions used across the gated delta rule implementations.

Also applies to: 425-428


148-161: Gating integration in forward kernel is well-implemented.

The integration of gating directly into the forward kernel using safe_exp(b_g_last - b_g)[:, None] is a clean implementation that ensures numerical stability. The placement of type conversion after gating multiplication is correct.


337-381: Backward kernel gating integration is consistent and correct.

The gating logic in the backward kernel correctly mirrors the forward pass:

  • b_dv is appropriately scaled by safe_exp(b_g_last - b_g)[:, None]
  • Query gradients are multiplied by b_g_exp before scaling
  • The order of operations ensures numerical stability

29-30: Autotuning configuration optimization.

The reduction in autotuning configurations (removing BV=16 and limiting warps/stages) should improve compilation time while maintaining performance, presumably because empirical testing showed the remaining configurations to be sufficient.

Also applies to: 201-203
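The cost being saved here is easy to quantify: Triton compiles one kernel variant per autotune config, so pruning the search space shrinks compile time multiplicatively. A toy illustration (the specific block sizes, warp counts, and stage counts below are made up for the example, not the values in chunk_delta_h.py):

```python
from itertools import product

def num_variants(block_sizes, warps, stages):
    # one compiled kernel variant per (block size, num_warps, num_stages) combo
    return len(list(product(block_sizes, warps, stages)))

# hypothetical search spaces before and after pruning
before = num_variants([16, 32, 64], [1, 2, 4, 8], [2, 3, 4])  # 3 * 4 * 3
after = num_variants([32, 64], [2, 4], [2, 3])                # 2 * 2 * 2
```

Dropping even one value per axis compounds, so a modest pruning like this cuts the number of compiled variants from 36 to 8.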

fla/ops/gated_delta_rule/wy_fast.py (4)

77-84: Gating gradient computation appears correct.

The computation of gate gradients through b_dg accumulation is mathematically sound:

  • Line 77: b_k_beta_g correctly combines key, beta, and exponential gate
  • Lines 82-83: Gate gradients are accumulated with appropriate chain rule application

99-105: Complex gradient computation for unified matrix A.

The gradient computation for the unified matrix A involving multiple matrix operations and gating adjustments appears mathematically correct:

  • Lower triangular masking ensures causality
  • The safe_exp gating adjustment maintains numerical stability
  • The final gradient accumulation (lines 121-122) correctly computes diagonal adjustments

Also applies to: 121-122


173-189: Forward kernel correctly implements unified WY representation.

The forward kernel has been successfully adapted to use the unified matrix A and integrated gating:

  • Line 180: u computation uses A directly
  • Lines 187-188: w computation correctly incorporates both beta and exponential gating
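In dense NumPy form, the unified recomputation the kernel performs might look like the sketch below. The shapes, the helper name `recompute_w_u`, and the exact placement of the exp(g) factor are inferences from the review notes, not the kernel's code.

```python
import numpy as np

def recompute_w_u(k, v, beta, A, g_cumsum):
    # u: value-side WY tensor, built from the unified A directly
    u = A @ (beta[:, None] * v)
    # w: key-side WY tensor, with beta and the exponential gate folded in
    w = A @ (beta[:, None] * np.exp(g_cumsum)[:, None] * k)
    return w, u

rng = np.random.default_rng(1)
T, K, V = 4, 8, 8
k_ = rng.standard_normal((T, K))
v_ = rng.standard_normal((T, V))
beta_ = rng.uniform(0.1, 1.0, T)
g_ = np.cumsum(np.log(rng.uniform(0.5, 1.0, T)))
A_ = np.tril(rng.standard_normal((T, T)))  # stand-in for the unified matrix
w_, u_ = recompute_w_u(k_, v_, beta_, A_, g_)
```

The point of the unification is visible even in this toy version: a single A serves both outputs, and the gating difference between w and u is a cheap elementwise factor rather than a second matrix inversion.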

192-201: Function signatures properly updated for unified approach.

The function signature updates correctly reflect the architectural changes:

  • Replacing Aw, Au with unified A and g_cumsum/g
  • Deriving BT from A.shape[-1] provides better flexibility

Also applies to: 231-242

🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 192-192: Too many arguments (6/5)

(R0913)


[refactor] 192-192: Too many positional arguments (6/5)

(R0917)


[refactor] 192-192: Too many local variables (18/15)

(R0914)

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🔭 Outside diff range comments (3)
fla/ops/gated_delta_rule/chunk.py (2)

204-225: ⚠️ Potential issue

Critical: Backward function incompatible with forward changes.

The backward function expects to unpack Aw, Au from saved tensors, but the forward pass now saves only A. Additionally, the calls to recompute_w_u_fwd and prepare_wy_repr_bwd use the old signatures with separate Aw and Au parameters.

The backward function needs to be updated to match the forward changes:

 def backward(
     ctx,
     do: torch.Tensor,
     dht: torch.Tensor
 ):
-    q, k, v, g, beta, Aw, Au, initial_state, cu_seqlens = ctx.saved_tensors
+    q, k, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
     if ctx.use_qk_l2norm_in_kernel:
         q, q_orig = l2norm_fwd(q), q
         k, k_orig = l2norm_fwd(k), k
     dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
         q=q,
         k=k,
         v=v,
         g=g,
         beta=beta,
-        Aw=Aw,
-        Au=Au,
+        A=A,
         scale=ctx.scale,
         initial_state=initial_state,
         do=do,
         dht=dht,
         cu_seqlens=cu_seqlens,
     )

Also, the chunk_gated_delta_rule_bwd function signature and its internal calls need corresponding updates.


73-94: ⚠️ Potential issue

Backward function signature and implementation need updating.

The chunk_gated_delta_rule_bwd function still uses the old signature with separate Aw and Au parameters and calls recompute_w_u_fwd with the outdated parameter names.

Update the function to use the consolidated A tensor:

 def chunk_gated_delta_rule_bwd(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
     beta: torch.Tensor,
-    Aw: torch.Tensor,
-    Au: torch.Tensor,
+    A: torch.Tensor,
     scale: float,
     initial_state: torch.Tensor,
     do: torch.Tensor,
     dht: torch.Tensor,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ):
     w, u = recompute_w_u_fwd(
         k=k,
         v=v,
         beta=beta,
-        Aw=Aw,
-        Au=Au,
+        A=A,
+        g_cumsum=g,
         cu_seqlens=cu_seqlens,
     )
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 73-73: Too many arguments (12/5)

(R0913)


[refactor] 73-73: Too many positional arguments (12/5)

(R0917)


[refactor] 73-73: Too many local variables (26/15)

(R0914)


[error] 87-94: Unexpected keyword argument 'Aw' in function call

(E1123)


[error] 87-94: Unexpected keyword argument 'Au' in function call

(E1123)


[error] 87-94: No value for argument 'g_cumsum' in function call

(E1120)


[error] 87-94: No value for argument 'A' in function call

(E1120)

fla/ops/common/chunk_scaled_dot_kkt.py (1)

82-103: 💡 Verification agent

🧩 Analysis chain

Verify callers are updated for the single return value.

The function now returns a single tensor A instead of a tuple (A, Ag). While the AI summary indicates all callers have been updated, please ensure all usage sites handle the new single return value correctly.

Run the following script to verify all callers have been updated:

Also applies to: 121-121


🏁 Script executed:

#!/bin/bash
# Description: Verify all calls to chunk_scaled_dot_kkt_fwd handle single return value

# Search for function calls and check for tuple unpacking
rg -A 2 'chunk_scaled_dot_kkt_fwd\(' --type py

Length of output: 1143


Fix tuple unpacking for updated return value

The function chunk_scaled_dot_kkt_fwd no longer returns (A, Ag), so any call sites that unpack two values must be updated:

• tests/ops/test_solve_tril.py

-    A, _ = chunk_scaled_dot_kkt_fwd(k, beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size)
+    A = chunk_scaled_dot_kkt_fwd(k, beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size)

• fla/ops/path_attn/parallel.py

-        A, _ = chunk_scaled_dot_kkt_fwd(
-            k=w,
-            beta=beta,
-
-        )
+        A = chunk_scaled_dot_kkt_fwd(
+            k=w,
+            beta=beta,
+
+        )

All other call sites already use a single assignment. Please update these two sites to avoid unpacking errors.

🧹 Nitpick comments (1)
fla/ops/common/chunk_delta_h.py (1)

462-480: Clean up commented preprocessing code.

The preprocessing code has been successfully removed. Consider removing these commented lines entirely to keep the codebase clean.

-    # if g is not None:
-    # q_new = torch.empty_like(q)
-    # k_new = torch.empty_like(k)
-    # w_new = torch.empty_like(w)
-    # def grid(meta): return (triton.cdiv(K, meta['BK']), N*H, triton.cdiv(T, BT))
-    # preprocess_qkw[grid](
-    #     q=q,
-    #     k=k,
-    #     w=w,
-    #     g=g,
-    #     q_new=q_new,
-    #     k_new=k_new,
-    #     w_new=w_new,
-    #     cu_seqlens=cu_seqlens,
-    #     T=T,
-    #     H=H,
-    #     K=K,
-    #     BT=BT,
-    # )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4c49281 and a343221.

📒 Files selected for processing (6)
  • fla/ops/common/chunk_delta_h.py (5 hunks)
  • fla/ops/common/chunk_scaled_dot_kkt.py (3 hunks)
  • fla/ops/delta_rule/wy_fast.py (1 hunks)
  • fla/ops/gated_delta_rule/chunk.py (4 hunks)
  • fla/ops/gated_delta_rule/wy_fast.py (3 hunks)
  • fla/ops/path_attn/parallel.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (3)
fla/ops/delta_rule/wy_fast.py (1)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
  • chunk_scaled_dot_kkt_fwd (74-121)
fla/ops/common/chunk_delta_h.py (1)
fla/ops/utils/op.py (1)
  • safe_exp (28-29)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
fla/ops/utils/op.py (1)
  • safe_exp (28-29)
🪛 Pylint (3.3.7)
fla/ops/gated_delta_rule/wy_fast.py

[refactor] 198-198: Too many arguments (6/5)

(R0913)


[refactor] 198-198: Too many positional arguments (6/5)

(R0917)


[refactor] 198-198: Too many local variables (18/15)

(R0914)

⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (14)
fla/ops/path_attn/parallel.py (1)

34-40: LGTM!

The change correctly handles the updated return signature of chunk_scaled_dot_kkt_fwd, which now returns a single tensor instead of a tuple.

fla/ops/delta_rule/wy_fast.py (1)

186-193: LGTM!

The change correctly adapts to the new return signature of chunk_scaled_dot_kkt_fwd.

fla/ops/gated_delta_rule/chunk.py (3)

32-50: Forward pass correctly updated to use single A tensor.

The changes properly consolidate the WY representation into a single tensor A with integrated gating.


70-70: Return signature correctly updated.

The function now returns the consolidated A tensor instead of separate Aw and Au.


191-191: Forward correctly saves single A tensor.

The save_for_backward is updated to save the consolidated A tensor.

fla/ops/gated_delta_rule/wy_fast.py (3)

147-196: Forward kernel correctly updated for consolidated A tensor.

The kernel now properly handles the single A tensor with integrated gating through the g parameter.

🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 147-147: Too many arguments (17/5)

(R0913)


[refactor] 147-147: Too many positional arguments (17/5)

(R0917)


[refactor] 147-147: Too many local variables (42/15)

(R0914)


198-207: Function signature and implementation correctly updated.

The forward function now accepts g_cumsum and A parameters, and derives BT from A.shape appropriately.

🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 198-198: Too many arguments (6/5)

(R0913)


[refactor] 198-198: Too many positional arguments (6/5)

(R0917)


[refactor] 198-198: Too many local variables (18/15)

(R0914)


237-247: Verify backward function compatibility with forward changes.

The prepare_wy_repr_bwd function still expects separate Aw and Au parameters, which is inconsistent with the forward pass that now uses a single consolidated A tensor.

Please clarify whether:

  1. The backward function should be updated to accept a single A parameter
  2. The backward computation requires reconstructing Aw and Au from A and g
  3. The calling code in chunk_gated_delta_rule_bwd needs updating

This inconsistency will cause runtime errors when the backward pass is executed with the current forward implementation.

🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 237-237: Too many arguments (9/5)

(R0913)


[refactor] 237-237: Too many positional arguments (9/5)

(R0917)


[refactor] 237-237: Too many local variables (24/15)

(R0914)

fla/ops/common/chunk_scaled_dot_kkt.py (2)

25-25: LGTM! Appropriate autotuning key update.

Including USE_G in the autotuning key is correct since the kernel's computation path differs based on whether gating is enabled.


66-71: Efficient in-place gating computation.

The optimization correctly applies gating directly to b_A instead of computing a separate tensor. The use of safe_exp ensures numerical stability, and moving the masking/storage after the conditional block ensures the final (potentially gated) result is stored correctly.

fla/ops/common/chunk_delta_h.py (4)

11-11: Correct import for the new gating computation.

The addition of safe_exp import is necessary for the integrated gating computation in the kernel.


143-161: Well-integrated gating computation in the forward kernel.

The gating logic has been successfully integrated into the main kernel:

  • Proper storage of v_new with correct dtype conversion
  • Efficient loading of gate values using block pointers
  • Correct application of gating using safe_exp(b_g_last - b_g) to prevent numerical overflow
  • Appropriate ordering with dtype conversion after gating computation

This eliminates the need for preprocessing and reduces memory operations.


415-417: Correct parameter passing after preprocessing removal.

The forward kernel now receives the original tensors k, w (as d) directly, which is consistent with the integrated gating computation.


484-486: Consistent parameter passing in backward kernel.

The backward kernel now receives the original tensors q, k, w (as d) directly, maintaining consistency with the forward pass changes.

@sustcsonglin sustcsonglin changed the title from "[GDN kernel optimization] Refactor the kernel to remove one matrix inversion" to "[Gated DeltaNet] Refactor the kernel to remove one matrix inversion" on Jun 9, 2025
@sustcsonglin sustcsonglin merged commit f03cb3a into main Jun 9, 2025
3 of 5 checks passed
@sustcsonglin sustcsonglin deleted the gated_deltanet branch June 9, 2025 05:58