Conversation
Walkthrough

This PR updates multiple modules related to chunk-based operations and delta rules. In the common operations module, an optional `g_cumsum` parameter is added to `chunk_scaled_dot_kkt_fwd`.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant ChunkKernel
    Caller->>ChunkKernel: call chunk_scaled_dot_kkt_fwd(g_cumsum, ...)
    Note right of ChunkKernel: Check if USE_G is True
    alt USE_G True
        ChunkKernel->>ChunkKernel: Compute b_g_diff using g_cumsum
        ChunkKernel->>ChunkKernel: Compute b_Ag from b_A & exp(b_g_diff)
    else USE_G False
        ChunkKernel->>ChunkKernel: Skip Ag computation
    end
    ChunkKernel->>Caller: Return A and Ag
```

```mermaid
sequenceDiagram
    participant Caller
    participant BackwardKernel
    Caller->>BackwardKernel: call bwd_prepare_wy_repr_kernel(v, dw, du, dk, dv, dbeta, dg, ...)
    BackwardKernel->>BackwardKernel: Initialize gradient accumulators (b_dA, b_dbeta, b_dA2)
    BackwardKernel->>BackwardKernel: Compute backward gradients via tensor operations
    BackwardKernel->>Caller: Return computed gradients
```
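The first diagram's contract can be sketched in NumPy for a single chunk. This is a hedged illustration only: the chunk layout, the strictly-lower-triangular masking, and the exact `safe_exp` semantics are assumptions, not the Triton kernel's real signature.

```python
import numpy as np

def chunk_scaled_dot_kkt_fwd_sketch(k, beta, g_cumsum=None):
    """Sketch of the (A, Ag) contract for one [T, K] chunk.

    A is the beta-scaled k @ k.T restricted to the strictly lower
    triangle; Ag additionally applies the gating factor exp(g_i - g_j)
    and is only materialized when g_cumsum is provided.
    """
    A = np.tril(beta[:, None] * (k @ k.T), k=-1)
    if g_cumsum is None:
        return A, None
    b_g_diff = g_cumsum[:, None] - g_cumsum[None, :]
    # damp the exponent the way safe_exp is assumed to: exponentiate
    # only non-positive differences so the result cannot overflow
    gate = np.where(b_g_diff <= 0, np.exp(np.minimum(b_g_diff, 0.0)), 0.0)
    return A, A * gate
```

A caller would then unpack the tuple as `A, Ag = ...` and skip the `Ag` path when gating is disabled.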
Possibly related PRs
Suggested reviewers
Poem
📜 Recent review details

Configuration used: CodeRabbit UI

📒 Files selected for processing (9)
✅ Files skipped from review due to trivial changes (3)
🚧 Files skipped from review as they are similar to previous changes (2)
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (2)
⏰ Context from checks skipped due to timeout of 90000ms (2)
🔇 Additional comments (20)
✨ Finishing Touches
Actionable comments posted: 0
🔭 Outside diff range comments (1)
fla/ops/gated_delta_rule/wy_fast.py (1)
240-240: 🛠️ Refactor suggestion: Introduction of `bwd_prepare_wy_repr` function.

Provides the new public API to compute the backward WY representation. Ensure that any references to the old forward-based kernel are removed throughout the codebase.
🧹 Nitpick comments (5)
fla/ops/utils/solve_tril.py (1)
163-174: Repeated float32 casting across larger blocks.

For these merged blocks, confirm that casting all loaded data to `float32` aligns with your precision needs. If half precision suffices, you can preserve performance by avoiding unneeded upcasting.

fla/ops/gated_delta_rule/chunk.py (1)

30-30: Hard-coded chunk size of 64 in `chunk_local_cumsum`.

Using a fixed chunk size simplifies usage but might reduce flexibility if future changes require varying chunk sizes. Consider making it configurable if future scenarios demand it.
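As a reference for what a configurable variant would look like, here is a minimal NumPy sketch of a chunk-local cumsum; the name and signature mirror `chunk_local_cumsum` but are illustrative only, not the Triton kernel's API.

```python
import numpy as np

def chunk_local_cumsum_ref(g, chunk_size=64):
    # Cumulative sum restarted at every chunk boundary along the time
    # axis; the Triton kernel fixes chunk_size to 64, while this sketch
    # keeps it as a parameter.
    out = np.empty_like(g)
    for start in range(0, g.shape[0], chunk_size):
        block = g[start:start + chunk_size]
        out[start:start + chunk_size] = np.cumsum(block, axis=0)
    return out
```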
fla/ops/common/chunk_scaled_dot_kkt.py (1)
31-31: New parameters `g_cumsum` and `Ag`.

Adding `g_cumsum` to the kernel and returning `Ag` expands the functionality. Verify that memory usage is acceptable when storing this extra tensor.

Also applies to: 33-33
fla/ops/gated_delta_rule/wy_fast.py (2)
45-45: New constants for V and BV.

Defines the value dimensions. This brings flexibility for the backward-pass computation, but watch out for memory alignment on large V.
Also applies to: 48-48
252-252: Fixed `BT = 64` in the backward interface.

Hard-coded chunk size again. Consider allowing a parameter if future expansions or hardware constraints differ.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
fla/ops/common/chunk_scaled_dot_kkt.py (5 hunks)
fla/ops/delta_rule/chunk.py (7 hunks)
fla/ops/delta_rule/wy_fast.py (4 hunks)
fla/ops/gated_delta_rule/chunk.py (4 hunks)
fla/ops/gated_delta_rule/wy_fast.py (4 hunks)
fla/ops/utils/solve_tril.py (3 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (4)
fla/ops/gated_delta_rule/chunk.py (5)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
chunk_scaled_dot_kkt_fwd(77-123)
fla/ops/delta_rule/wy_fast.py (2)
bwd_prepare_wy_repr(247-288)
fwd_recompute_w_u(209-244)
fla/ops/gated_delta_rule/wy_fast.py (2)
bwd_prepare_wy_repr(240-286)
fwd_recompute_w_u(201-237)
fla/ops/utils/cumsum.py (1)
chunk_local_cumsum(386-406)
fla/ops/utils/solve_tril.py (1)
solve_tril(222-276)
fla/ops/gated_delta_rule/wy_fast.py (4)
fla/ops/delta_rule/wy_fast.py (1)
bwd_prepare_wy_repr_kernel(91-177)
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py (1)
bwd_prepare_wy_repr_kernel(30-114)
fla/ops/utils/op.py (1)
safe_exp(28-29)
fla/ops/common/utils.py (1)
prepare_chunk_indices(59-64)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
fla/ops/utils/op.py (1)
safe_exp(28-29)
fla/ops/delta_rule/wy_fast.py (2)
fla/ops/utils/solve_tril.py (1)
solve_tril(222-276)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
chunk_scaled_dot_kkt_fwd(77-123)
⏰ Context from checks skipped due to timeout of 90000ms (1)
- GitHub Check: test
🔇 Additional comments (38)
fla/ops/utils/solve_tril.py (2)
51-51: Consider verifying the performance impact of float32 casting.

Casting this loaded block to `float32` may improve numerical stability but could increase memory usage. Confirm that performance remains acceptable given your hardware constraints.

107-109: Same float32 casting note as above.

Again, ensure that repeatedly casting these blocks to `float32` is intended from both a memory-footprint and a compute-performance perspective.

fla/ops/gated_delta_rule/chunk.py (9)

13-15: Imports consolidated for chunk-based gating.

No immediate issues. These new imports enable the forward/backward transformations using chunk-based logic.

28-28: Optional `cu_seqlens` parameter.

Marking `cu_seqlens` as optional reduces code complexity, but ensure all call sites handle `None` gracefully.

32-38: Using `chunk_scaled_dot_kkt_fwd` with `g_cumsum`.

The call to compute both `Aw` and `Au` from `g_cumsum` is logical for the gating approach. Verify the correct dtype is returned as requested (`torch.float32`).

39-43: `solve_tril` usage for `Aw`.

Good approach to invert the lower triangular portion of `Aw`. Double-check that all code paths handle the shape `[B, T, H, 16/32/64]` as expected.
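For reference, inverting a unit-lower-triangular chunk matrix can be done by forward substitution. The sketch below is an assumption about what `solve_tril` computes per `[BT, BT]` block, not its actual Triton implementation.

```python
import numpy as np

def invert_unit_lower(A):
    # Forward substitution: solve A @ X = I row by row, exploiting the
    # unit diagonal so no division is needed.
    n = A.shape[-1]
    inv = np.eye(n)
    for i in range(1, n):
        for j in range(i):
            inv[i] -= A[i, j] * inv[j]
    return inv
```

The defining property to test is that the result composes with the input to the identity.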
44-48: `solve_tril` usage for `Au`.

Similar logic as for `Aw`. Ensures `Au` is also inverted if needed. Confirm that in-place modifications won't conflict with parallel usage.

49-55: Recomputed `w` and `u` from `Aw` and `Au`.

Forward recomputation is consistent with the gating logic. No apparent issues; ensure shape alignment is tested.

129-129: Empty changed line.

No actionable content.

155-155: Addition of gradients to `dg`.

Combining partial gradients is standard. Confirm that no additional scaling or normalization is needed in this step.

157-157: Reversing chunked cumsum for the gradient.

This reversed cumsum ensures correct backprop for gating. Make sure test coverage verifies boundary conditions for variable-length inputs.
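The reversed chunk-local cumsum used for the gating gradient can be sketched as follows. This is a hedged NumPy reference; the real kernel operates on `[B, T, H]` tensors with variable-length support.

```python
import numpy as np

def chunk_local_cumsum_reversed(dg, chunk_size=64):
    # Suffix sum within each chunk: position t accumulates gradients
    # from t to the end of its chunk, which is the adjoint of the
    # forward chunk-local cumsum.
    out = np.empty_like(dg)
    for start in range(0, dg.shape[0], chunk_size):
        block = dg[start:start + chunk_size]
        out[start:start + chunk_size] = np.cumsum(block[::-1], axis=0)[::-1]
    return out
```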
fla/ops/common/chunk_scaled_dot_kkt.py (7)
11-11: Import `safe_exp` for exponent damping.

This may help prevent overflow, but ensure large negative values are handled correctly.

15-16: New heuristics for variable-length and gating usage.

`USE_G` is defined based on `g_cumsum`. The logic is straightforward and helps conditionally compute the gating matrix.

68-75: Conditional gating logic.

When `USE_G` is true, the code computes `b_Ag = b_A * safe_exp(b_g_diff)`. Make sure `b_g_diff` is bounded, or it could lead to NaN if the differences are positive and large.
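The boundedness concern can be checked against a reference definition of the guard. The semantics below, zeroing out positive arguments rather than exponentiating them, are an assumption about `safe_exp`'s contract, sketched in NumPy.

```python
import numpy as np

def safe_exp_ref(x):
    # Overflow-proof exponential: positive arguments, which would blow
    # up exp(), are mapped to 0 instead (the assumed safe_exp contract).
    # np.minimum keeps the exp argument non-positive even on the branch
    # that is ultimately discarded, so no overflow occurs at all.
    return np.where(x <= 0, np.exp(np.minimum(x, 0.0)), 0.0)
```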
80-80: Optional `g_cumsum` type annotation.

No issues; consistent with your approach.

109-109: Allocate `Ag` only if `g_cumsum` is not None.

Saves memory when gating is not used; a reasonable approach.

113-115: Passing `g_cumsum` and `Ag` to the kernel in the launch call.

Aligned with the new logic to compute the gating portion.

123-123: Returning `(A, Ag)` as a tuple.

The introduction of `Ag` as a second matrix is consistent with the gating extension. Confirm that all call sites correctly handle the tuple return.

fla/ops/gated_delta_rule/wy_fast.py (7)

21-21: Narrowing warp configurations from [2, 4] to optimize tuning.

Fewer warp sizes can reduce tuning time but might limit performance on some GPUs.

24-24: Extended autotune key.

Now includes `['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN']`. This ensures the specific dimension parameters drive kernel configuration.

27-27: New `bwd_prepare_wy_repr_kernel` definition.

Replaces the forward kernel with a backward approach for the WY representation. This is a major addition; ensure thorough testing of backward-pass correctness.

29-39: New parameters for the backward pass (v, g, dw, du, dk, dv, dbeta, dg).

These additions allow gradient accumulation and distribution across w, u, k, v, and the gating components. Make sure each is sized properly at call time.

60-61: Initializing local gradient accumulators.

Storing partial sums in `b_dbeta` and `b_dA`. Ensure these are zeroed as intended before the partial dot accumulation.

71-75: Dot products for dW and dK.

These lines accumulate partial derivatives with controlled TF32 usage. Confirm this matches the forward-pass precision.

129-132: Storing computed `dg` and `dbeta`.

Completes the gradient updates for gating and betas. Verify that no stride issues occur with large batch sizes.
fla/ops/delta_rule/wy_fast.py (4)
12-12: LGTM: Updated import path for `solve_tril`.

The import has been updated to use the more modular approach from the utils package.

186-193: Added `g_cumsum` parameter and hardcoded chunk_size.

The `chunk_scaled_dot_kkt_fwd` call has been updated to include the new `g_cumsum` parameter (set to `None`) and to explicitly set `chunk_size=64` rather than using a parameter. This aligns with the optimization focus of the PR, where chunk_size is standardized across the codebase.

217-217: Hardcoded block tile size to 64.

The variable `BT` is now directly set to 64 rather than being derived from the chunk_size parameter, which simplifies the implementation and makes the code more predictable.

257-257: Improved BT calculation based on input shape.

`BT` is now derived from `A.shape[-1]` instead of relying on the chunk_size parameter, which makes the code more robust as it adapts to the actual dimensions of the input tensor.
fla/ops/delta_rule/chunk.py (9)
28-33: LGTM: Simplified function call by removing the chunk_size parameter.

The call to `fwd_prepare_wy_repr` has been cleaned up by removing the chunk_size parameter, which is now handled internally in the called function.

34-42: Simplified `chunk_gated_delta_rule_fwd_h` call.

The chunk_size parameter has been removed from the function call, streamlining the API. This is consistent with the optimization-focused changes across the codebase.

44-52: LGTM: Removed chunk_size from the `chunk_fwd_o` call.

Consistent with the other changes, the chunk_size parameter has been removed from the function call, simplifying the interface.

68-74: Simplified `fwd_recompute_w_u` call by removing the chunk_size parameter.

The function call is now more straightforward with the removal of the chunk_size parameter, which is now handled internally within the function.

84-91: LGTM: Simplified `chunk_bwd_dv_local` call.

The chunk_size parameter has been removed, which aligns with the optimization changes being made throughout the codebase.

92-103: Simplified `chunk_gated_delta_rule_bwd_dhu` call.

The removal of the chunk_size parameter from this function call maintains consistency with the other optimization changes in the PR.

104-116: LGTM: Simplified `chunk_bwd_dqkwg` call.

The chunk_size parameter has been removed from this call as well, maintaining a consistent approach to API simplification across the codebase.

117-125: Simplified `bwd_prepare_wy_repr` call.

The chunk_size parameter has been removed, aligning with the pattern of simplifying function interfaces throughout this PR.

184-195: LGTM: Simplified backward-pass function call.

The chunk_size parameter has been removed from the backward function call, completing the consistent simplification of the API across the codebase.
Actionable comments posted: 0
🧹 Nitpick comments (1)
fla/ops/common/chunk_delta_h.py (1)
363-369: Consolidated variable usage for the K > 192 case, with a minor inconsistency.

While the variable-reuse pattern is maintained, there seems to be an inconsistency with the `p_d` variable on line 366. Unlike the previous sections, where both `p_q` and `p_d` are newly defined, here only a new `b_d` is loaded without creating a new `p_d` pointer. For complete consistency with the previous patterns, consider adding:

```diff
 p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
+p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
 b_q = tl.load(p_q, boundary_check=(0, 1))
 b_q = (b_q * scale).to(b_q.dtype)
 b_d = tl.load(p_d, boundary_check=(0, 1))
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
fla/ops/common/chunk_delta_h.py(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: test
- GitHub Check: test
🔇 Additional comments (5)
fla/ops/common/chunk_delta_h.py (5)
318-321: Code optimization: consolidation of pointer and tensor variables.

The code now reuses the `p_k` and `b_k` variables instead of having separate variables for each condition, reducing redundancy and improving maintainability.

324-326: Consistent variable-reuse pattern applied for the K > 128 case.

Good continuation of the variable consolidation pattern, maintaining consistency with the previous optimization.

329-331: Consistent variable-reuse pattern applied for the K > 192 case.

The same optimization approach is extended to the final case, completing the variable consolidation pattern throughout all conditional branches.

345-352: Comprehensive refactoring of the backward kernel logic for K > 64.

The refactoring here extends beyond simple variable reuse to include a more streamlined approach for handling the `q` and `d` tensors in the K > 64 case. The code is now more consistent and better organized.

354-361: Consistent refactoring pattern applied for the K > 128 case.

This change maintains the same optimization pattern as the K > 64 case, creating a uniform approach throughout the function.
Hello! Is this optimization available by default for DeltaNet/GDN now? And do you have a rough estimate of the speedup?
Summary by CodeRabbit
New Features
Refactor
Chores
Bug Fixes