[Gated DeltaNet] Refactor the kernel to remove one matrix inversion #433

sustcsonglin merged 3 commits into main
Conversation
Walkthrough

This update removes one matrix inversion from the gated delta rule path: the separate `Aw`/`Au` tensors are consolidated into a single `A` tensor, gating is applied in-place inside the kernels via `safe_exp`, and the standalone `preprocess_qkw` pass is removed.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host
    participant MainFwdKernel
    participant GateTensor
    Host->>MainFwdKernel: Launch forward kernel with k, w, q, g
    MainFwdKernel->>GateTensor: Load gate values (if USE_G)
    MainFwdKernel->>MainFwdKernel: Apply gating via safe_exp(g_last - g)
    MainFwdKernel->>Host: Return output tensors
```

```mermaid
sequenceDiagram
    participant Host
    participant ScaledDotKKTKernel
    Host->>ScaledDotKKTKernel: Launch kernel with k, beta, g_cumsum
    ScaledDotKKTKernel->>ScaledDotKKTKernel: Modify A in-place if USE_G
    ScaledDotKKTKernel->>Host: Return A
```
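The in-place step in the second diagram ("Modify A in-place if USE_G") can be mimicked in plain Python. This is an illustrative toy, not the Triton kernel: `exp_clamped` stands in for `safe_exp` under the assumption that it only exponentiates non-positive arguments, the pairwise decay factor `exp(g_i - g_j)` is the usual gated-linear-attention convention (the kernel's exact broadcast may differ), and all numbers are made up.

```python
import math

def exp_clamped(x):
    # Assumed stand-in for safe_exp: clamp at zero so the
    # exponential can never overflow for large log-gates.
    return math.exp(min(x, 0.0))

# Toy strictly-lower-triangular KKT scores and cumulative log-gates
g = [0.0, -0.5, -1.2]            # non-increasing within the chunk
A = [[0.0, 0.0, 0.0],
     [0.3, 0.0, 0.0],
     [0.1, 0.4, 0.0]]

# Gate A in place, mirroring b_A *= safe_exp(g_i - g_j) for i > j
for i in range(3):
    for j in range(i):
        A[i][j] *= exp_clamped(g[i] - g[j])
```

Because the gating overwrites `A` rather than producing a separate `Ag`, the kernel returns one tensor and performs one store, which is the memory saving the refactor targets.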
📜 Recent review details

Configuration used: CodeRabbit UI
📒 Files selected for processing (4)
🚧 Files skipped from review as they are similar to previous changes (1)
🧰 Additional context used
🧬 Code Graph Analysis (1): fla/ops/common/chunk_delta_h.py (1)
🪛 Pylint (3.3.7): fla/ops/gated_delta_rule/wy_fast.py
[refactor] 192-192: Too many arguments (6/5) (R0913)
[refactor] 192-192: Too many positional arguments (6/5) (R0917)
[refactor] 192-192: Too many local variables (18/15) (R0914)
⏰ Context from checks skipped due to timeout of 90000ms (2)
🔇 Additional comments (9)
Actionable comments posted: 0
🔭 Outside diff range comments (3)
fla/ops/gated_delta_rule/chunk.py (2)
204-225: ⚠️ Potential issue (critical): backward function incompatible with forward changes.

The backward function expects to unpack `Aw, Au` from saved tensors, but the forward pass now saves only `A`. Additionally, the calls to `recompute_w_u_fwd` and `prepare_wy_repr_bwd` use the old signatures with separate `Aw` and `Au` parameters. The backward function needs to be updated to match the forward changes:
```diff
 def backward(
     ctx,
     do: torch.Tensor,
     dht: torch.Tensor
 ):
-    q, k, v, g, beta, Aw, Au, initial_state, cu_seqlens = ctx.saved_tensors
+    q, k, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
     if ctx.use_qk_l2norm_in_kernel:
         q, q_orig = l2norm_fwd(q), q
         k, k_orig = l2norm_fwd(k), k
     dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
         q=q,
         k=k,
         v=v,
         g=g,
         beta=beta,
-        Aw=Aw,
-        Au=Au,
+        A=A,
         scale=ctx.scale,
         initial_state=initial_state,
         do=do,
         dht=dht,
         cu_seqlens=cu_seqlens,
     )
```

Also, the `chunk_gated_delta_rule_bwd` function signature and its internal calls need corresponding updates.
73-94: ⚠️ Potential issue: backward function signature and implementation need updating.

The `chunk_gated_delta_rule_bwd` function still uses the old signature with separate `Aw` and `Au` parameters and calls `recompute_w_u_fwd` with the outdated parameter names. Update the function to use the consolidated `A` tensor:

```diff
 def chunk_gated_delta_rule_bwd(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
     beta: torch.Tensor,
-    Aw: torch.Tensor,
-    Au: torch.Tensor,
+    A: torch.Tensor,
     scale: float,
     initial_state: torch.Tensor,
     do: torch.Tensor,
     dht: torch.Tensor,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ):
     w, u = recompute_w_u_fwd(
         k=k,
         v=v,
         beta=beta,
-        Aw=Aw,
-        Au=Au,
+        A=A,
+        g_cumsum=g,
         cu_seqlens=cu_seqlens,
     )
```

🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 73-73: Too many arguments (12/5)
(R0913)
[refactor] 73-73: Too many positional arguments (12/5)
(R0917)
[refactor] 73-73: Too many local variables (26/15)
(R0914)
[error] 87-94: Unexpected keyword argument 'Aw' in function call
(E1123)
[error] 87-94: Unexpected keyword argument 'Au' in function call
(E1123)
[error] 87-94: No value for argument 'g_cumsum' in function call
(E1120)
[error] 87-94: No value for argument 'A' in function call
(E1120)
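For intuition on the titular "one matrix inversion": once gating is folded into the strictly-lower matrix before inversion, a single inverse serves where `Aw` and `Au` previously each required one. Below is a toy 2x2 sketch in plain Python with made-up numbers, using the closed-form inverse of a unit-lower-triangular matrix rather than the library's actual solve path; `exp_clamped` is an assumed stand-in for `safe_exp`.

```python
import math

def exp_clamped(x):
    # assumed safe_exp-like clamp: never exponentiate a positive value
    return math.exp(min(x, 0.0))

beta = [0.9, 0.8]     # per-position beta scaling (illustrative)
kkt_10 = 0.5          # the one strictly-lower entry k_1 . k_0
g = [0.0, -0.4]       # cumulative log-gates

# Fold the gate into A *before* inverting: one tensor, one inversion
a10 = beta[1] * kkt_10 * exp_clamped(g[1] - g[0])

# (I + A) is unit lower triangular, so its inverse is closed form
I_plus_A = [[1.0, 0.0], [a10, 1.0]]
A_inv = [[1.0, 0.0], [-a10, 1.0]]

# sanity check: A_inv @ (I + A) should be the identity
prod = [[sum(A_inv[i][t] * I_plus_A[t][j] for t in range(2))
         for j in range(2)] for i in range(2)]
```

The consolidated backward path then only needs this single `A` (plus `g_cumsum`) saved from the forward pass, which is why the diff above threads `A=A, g_cumsum=g` through `recompute_w_u_fwd`.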
fla/ops/common/chunk_scaled_dot_kkt.py (1)
82-103: 💡 Verification agent, 🧩 Analysis chain

Verify callers are updated for the single return value.

The function now returns a single tensor `A` instead of a tuple `(A, Ag)`. While the AI summary indicates all callers have been updated, please ensure all usage sites handle the new single return value correctly. Run the following script to verify all callers have been updated:
Also applies to: 121-121
🏁 Script executed:

```shell
#!/bin/bash
# Description: Verify all calls to chunk_scaled_dot_kkt_fwd handle single return value
# Search for function calls and check for tuple unpacking
rg -A 2 'chunk_scaled_dot_kkt_fwd\(' --type py
```

Length of output: 1143
Fix tuple unpacking for updated return value
The function `chunk_scaled_dot_kkt_fwd` no longer returns `(A, Ag)`, so any call sites that unpack two values must be updated:

• tests/ops/test_solve_tril.py

```diff
-    A, _ = chunk_scaled_dot_kkt_fwd(k, beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size)
+    A = chunk_scaled_dot_kkt_fwd(k, beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size)
```

• fla/ops/path_attn/parallel.py

```diff
-    A, _ = chunk_scaled_dot_kkt_fwd(
-        k=w,
-        beta=beta,
-        …
-    )
+    A = chunk_scaled_dot_kkt_fwd(
+        k=w,
+        beta=beta,
+        …
+    )
```

All other callers already use a single assignment. Please update these two sites to avoid unpacking errors.
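As a side note on this migration: old-style two-value unpacking of the new single return value does not always fail loudly, because iterating a 2-row result binds the rows instead of raising. A tiny stub (not the real function's signature or return type) demonstrates the hazard:

```python
def chunk_scaled_dot_kkt_fwd_stub():
    # Stand-in for the refactored function: returns ONE matrix
    # (nested lists here instead of a real tensor).
    return [[1.0, 0.0], [0.5, 1.0]]

A = chunk_scaled_dot_kkt_fwd_stub()        # new, correct call site

# Legacy `A, _ = ...` does not raise for a 2-row result: it silently
# unpacks the two ROWS, leaving the first name bound to a single row.
row, _ = chunk_scaled_dot_kkt_fwd_stub()
```

This is why auditing the call sites with `rg`, as the script above does, is safer than relying on runtime errors to surface stale unpacking.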
🧹 Nitpick comments (1)
fla/ops/common/chunk_delta_h.py (1)
462-480: Clean up commented preprocessing code.

The preprocessing code has been successfully removed. Consider removing these commented lines entirely to keep the codebase clean.
```diff
-    # if g is not None:
-    #     q_new = torch.empty_like(q)
-    #     k_new = torch.empty_like(k)
-    #     w_new = torch.empty_like(w)
-    #     def grid(meta): return (triton.cdiv(K, meta['BK']), N*H, triton.cdiv(T, BT))
-    #     preprocess_qkw[grid](
-    #         q=q,
-    #         k=k,
-    #         w=w,
-    #         g=g,
-    #         q_new=q_new,
-    #         k_new=k_new,
-    #         w_new=w_new,
-    #         cu_seqlens=cu_seqlens,
-    #         T=T,
-    #         H=H,
-    #         K=K,
-    #         BT=BT,
-    #     )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- fla/ops/common/chunk_delta_h.py (5 hunks)
- fla/ops/common/chunk_scaled_dot_kkt.py (3 hunks)
- fla/ops/delta_rule/wy_fast.py (1 hunk)
- fla/ops/gated_delta_rule/chunk.py (4 hunks)
- fla/ops/gated_delta_rule/wy_fast.py (3 hunks)
- fla/ops/path_attn/parallel.py (1 hunk)
🧰 Additional context used
🧬 Code Graph Analysis (3)
fla/ops/delta_rule/wy_fast.py (1)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
chunk_scaled_dot_kkt_fwd(74-121)
fla/ops/common/chunk_delta_h.py (1)
fla/ops/utils/op.py (1)
safe_exp(28-29)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
fla/ops/utils/op.py (1)
safe_exp(28-29)
🪛 Pylint (3.3.7)
fla/ops/gated_delta_rule/wy_fast.py
[refactor] 198-198: Too many arguments (6/5)
(R0913)
[refactor] 198-198: Too many positional arguments (6/5)
(R0917)
[refactor] 198-198: Too many local variables (18/15)
(R0914)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: test
- GitHub Check: test
🔇 Additional comments (14)
fla/ops/path_attn/parallel.py (1)
34-40: LGTM!

The change correctly handles the updated return signature of `chunk_scaled_dot_kkt_fwd`, which now returns a single tensor instead of a tuple.

fla/ops/delta_rule/wy_fast.py (1)
186-193: LGTM!

The change correctly adapts to the new return signature of `chunk_scaled_dot_kkt_fwd`.

fla/ops/gated_delta_rule/chunk.py (3)
32-50: Forward pass correctly updated to use single A tensor.

The changes properly consolidate the WY representation into a single tensor `A` with integrated gating.

70-70: Return signature correctly updated.

The function now returns the consolidated `A` tensor instead of separate `Aw` and `Au`.

191-191: Forward correctly saves single A tensor.

The `save_for_backward` is updated to save the consolidated `A` tensor.

fla/ops/gated_delta_rule/wy_fast.py (3)
147-196: Forward kernel correctly updated for consolidated A tensor.

The kernel now properly handles the single `A` tensor with integrated gating through the `g` parameter.

🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 147-147: Too many arguments (17/5)
(R0913)
[refactor] 147-147: Too many positional arguments (17/5)
(R0917)
[refactor] 147-147: Too many local variables (42/15)
(R0914)
198-207: Function signature and implementation correctly updated.

The forward function now accepts `g_cumsum` and `A` parameters, and derives `BT` from `A.shape` appropriately.

🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 198-198: Too many arguments (6/5)
(R0913)
[refactor] 198-198: Too many positional arguments (6/5)
(R0917)
[refactor] 198-198: Too many local variables (18/15)
(R0914)
237-247: Verify backward function compatibility with forward changes.

The `prepare_wy_repr_bwd` function still expects separate `Aw` and `Au` parameters, which is inconsistent with the forward pass that now uses a single consolidated `A` tensor. Please clarify whether:

- the backward function should be updated to accept a single `A` parameter
- the backward computation requires reconstructing `Aw` and `Au` from `A` and `g`
- the calling code in `chunk_gated_delta_rule_bwd` needs updating

This inconsistency will cause runtime errors when the backward pass is executed with the current forward implementation.
🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 237-237: Too many arguments (9/5)
(R0913)
[refactor] 237-237: Too many positional arguments (9/5)
(R0917)
[refactor] 237-237: Too many local variables (24/15)
(R0914)
fla/ops/common/chunk_scaled_dot_kkt.py (2)
25-25: LGTM! Appropriate autotuning key update.

Including `USE_G` in the autotuning key is correct since the kernel's computation path differs based on whether gating is enabled.

66-71: Efficient in-place gating computation.

The optimization correctly applies gating directly to `b_A` instead of computing a separate tensor. The use of `safe_exp` ensures numerical stability, and moving the masking/storage after the conditional block ensures the final (potentially gated) result is stored correctly.

fla/ops/common/chunk_delta_h.py (4)
11-11: Correct import for the new gating computation.

The addition of the `safe_exp` import is necessary for the integrated gating computation in the kernel.

143-161: Well-integrated gating computation in the forward kernel.

The gating logic has been successfully integrated into the main kernel:

- proper storage of `v_new` with correct dtype conversion
- efficient loading of gate values using block pointers
- correct application of gating using `safe_exp(b_g_last - b_g)` to prevent numerical overflow
- appropriate ordering, with dtype conversion after the gating computation
This eliminates the need for preprocessing and reduces memory operations.
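The chunk-boundary gating described above can be mimicked with scalars. This sketch assumes the standard gated recurrence (decay the carried state by the full-chunk gate, and each within-chunk contribution by its residual gate to the chunk end); names and numbers are illustrative, and `exp_clamped` is an assumed stand-in for `safe_exp`:

```python
import math

def exp_clamped(x):
    # assumed safe_exp-like behaviour: clamp the argument at zero
    return math.exp(min(x, 0.0))

g = [-0.2, -0.7, -1.0]      # cumulative log-gates within one chunk
g_last = g[-1]
k = [1.0, 2.0, 3.0]         # scalar stand-ins for key rows
v = [0.5, 0.5, 0.5]         # scalar stand-ins for value rows
h = 10.0                    # carried recurrent state

# Old state decays by the whole-chunk gate; each rank-1 term k_i v_i
# decays by the residual gate from position i to the chunk end.
h_new = exp_clamped(g_last) * h + sum(
    exp_clamped(g_last - gi) * ki * vi
    for gi, ki, vi in zip(g, k, v)
)
```

Doing this inside the main kernel, rather than in a preprocessing pass that materializes gated copies of the inputs, is the memory-traffic saving the review notes.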
415-417: Correct parameter passing after preprocessing removal.

The forward kernel now receives the original tensors `k` and `w` (as `d`) directly, which is consistent with the integrated gating computation.

484-486: Consistent parameter passing in backward kernel.

The backward kernel now receives the original tensors `q`, `k`, and `w` (as `d`) directly, maintaining consistency with the forward pass changes.