[Gated DeltaNet] optimize UT transform #349

Merged
yzhangcs merged 7 commits into main from fast-gdn
Apr 13, 2025

Conversation

Collaborator

@sustcsonglin sustcsonglin commented Apr 13, 2025

Summary by CodeRabbit

  • New Features

    • Enabled optional computation of an additional result when a cumulative adjustment input is provided.
  • Refactor

    • Simplified public interfaces by removing redundant size parameters, resulting in a cleaner API.
    • Revised internal logic for more efficient backward gradient computations.
    • Enhanced data type consistency through explicit conversions.
    • Introduced a new backward preparation kernel for the WY representation.
    • Consolidated pointer and tensor loading operations for improved clarity.
    • Renamed several functions for better clarity while maintaining functionality.
  • Chores

    • Updated module imports to include new functions in the public interface.
  • Bug Fixes

    • Improved clarity in test function calls by using keyword arguments for better readability.

Contributor

coderabbitai bot commented Apr 13, 2025

Walkthrough

This PR updates multiple modules related to chunk-based operations and delta rules. In the common operations module, an optional parameter g_cumsum is added along with a conditional control flow (using USE_G) to compute an extra output tensor. Several functions across delta rule and gated delta rule modules have their chunk_size parameter removed and hardcoded to 64 where needed, simplifying internal calculations. Moreover, the forward kernel for the WY representation in the gated delta rule module is replaced by a new backward kernel, and type conversions are enforced in the solve_tril utility.
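The new `g_cumsum` path can be illustrated with a rough NumPy emulation of what the walkthrough describes for `chunk_scaled_dot_kkt_fwd`: per chunk, build a strictly lower-triangular beta-scaled `K K^T` matrix `A`, and when `g_cumsum` is given, also compute `Ag = A * exp(g_i - g_j)`. The exact masking and scaling conventions here are assumptions for illustration, not a verbatim copy of the Triton kernel:

```python
import numpy as np

def chunk_scaled_dot_kkt_ref(k, beta, g_cumsum=None, chunk_size=64):
    """Illustrative reference for the USE_G branch (shapes simplified
    to a single [T, K] sequence; the real kernel handles [B, T, H, K])."""
    T = k.shape[0]
    A = np.zeros((T, chunk_size))
    Ag = np.zeros_like(A) if g_cumsum is not None else None
    for s in range(0, T, chunk_size):
        e = min(s + chunk_size, T)
        kc = k[s:e]                          # (C, K) chunk of keys
        bc = beta[s:e, None]                 # (C, 1) per-token scale
        a = np.tril((bc * kc) @ kc.T, k=-1)  # strictly causal within chunk
        A[s:e, : e - s] = a
        if g_cumsum is not None:
            g = g_cumsum[s:e]
            # gate-difference weighting; upper triangle is already zero in a
            Ag[s:e, : e - s] = a * np.exp(g[:, None] - g[None, :])
    return A, Ag
```

When `g_cumsum` is `None`, `Ag` is never allocated, matching the memory-saving behavior noted in the review below.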

Changes

File Change Summary
fla/ops/common/chunk_scaled_dot_kkt.py Added optional g_cumsum parameter and introduced USE_G heuristic to conditionally compute an additional tensor Ag; updated function signatures and return values accordingly.
fla/ops/delta_rule/chunk.py Removed the chunk_size parameter and its associated logic from chunk_delta_rule_fwd and chunk_delta_rule_bwd, simplifying function signatures and internal chunk size determination.
fla/ops/delta_rule/wy_fast.py Eliminated the chunk_size parameter from functions; replaced variable chunk size with a fixed value (64) in computations and updated calls to chunk_scaled_dot_kkt_fwd alongside modifying the import path for solve_tril.
fla/ops/gated_delta_rule/chunk.py Removed the chunk_size parameter in forward and backward functions; updated the forward pass to use chunk_scaled_dot_kkt_fwd for WY representation and adjusted calls to local cumsum functions with a fixed chunk size.
fla/ops/gated_delta_rule/wy_fast.py Replaced the forward WY kernel functions with a new backward kernel bwd_prepare_wy_repr_kernel; updated parameters, autotuning configurations, and internal gradient computations accordingly.
fla/ops/utils/solve_tril.py Added explicit .to(tl.float32) conversions on tensor loads to ensure type consistency and removed the previous assertion enforcing input tensor type.
fla/ops/utils/__init__.py Added import statement for solve_tril and included it in the __all__ list for module accessibility.
fla/ops/common/chunk_delta_h.py Consolidated pointer and tensor loading operations for k, q, and d tensors, reducing redundancy by reusing pointer and tensor variables across conditions based on K.
tests/ops/test_solve_tril.py Updated function call in test_solve_tril_varlen to use keyword arguments for clarity without altering logic.
fla/ops/generalized_delta_rule/dplr/chunk.py Renamed fwd_prepare_wy_repr to prepare_wy_repr_fwd in import and function calls within chunk_dplr_fwd and backward, preserving functionality.
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py Renamed bwd_prepare_wy_repr_kernel to prepare_wy_repr_bwd_kernel in definition and invocation, maintaining parameters and functionality.
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py Renamed multiple functions related to forward preparation; aliased original functions for backward compatibility, ensuring functionality remains unchanged.
fla/ops/generalized_delta_rule/iplr/chunk.py Updated import and function call from fwd_prepare_wy_repr to prepare_wy_repr_fwd, maintaining parameters and functionality.
fla/ops/generalized_delta_rule/iplr/wy_fast.py Renamed several functions and modified signatures, including new parameters and return types, while preserving core functionality.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant ChunkKernel
    Caller->>ChunkKernel: call chunk_scaled_dot_kkt_fwd(g_cumsum, ...)
    Note right of ChunkKernel: Check if USE_G is True
    alt USE_G True
        ChunkKernel->>ChunkKernel: Compute b_g_diff using g_cumsum
        ChunkKernel->>ChunkKernel: Compute b_Ag from b_A & exp(b_g_diff)
    else USE_G False
        ChunkKernel->>ChunkKernel: Skip Ag computation
    end
    ChunkKernel->>Caller: Return A and Ag
sequenceDiagram
    participant Caller
    participant BackwardKernel
    Caller->>BackwardKernel: call bwd_prepare_wy_repr_kernel(v, dw, du, dk, dv, dbeta, dg, ...)
    BackwardKernel->>BackwardKernel: Initialize gradient accumulators (b_dA, b_dbeta, b_dA2)
    BackwardKernel->>BackwardKernel: Compute backward gradients via tensor operations
    BackwardKernel->>Caller: Return computed gradients

Suggested reviewers

  • yzhangcs

Poem

I'm just a rabbit with speedy feet,
Hopping through code with beats so neat.
New parameters bloom like a spring so bright,
And functions now sing in the fixed 64 light.
A whiskered cheer for kernels that compute,
Our code garden grows, undeniably cute!
(_/), (•_•), (❁´◡`❁)


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5e39abb and acdd6c9.

📒 Files selected for processing (9)
  • fla/ops/delta_rule/chunk.py (7 hunks)
  • fla/ops/delta_rule/wy_fast.py (8 hunks)
  • fla/ops/gated_delta_rule/chunk.py (6 hunks)
  • fla/ops/gated_delta_rule/wy_fast.py (8 hunks)
  • fla/ops/generalized_delta_rule/dplr/chunk.py (3 hunks)
  • fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py (2 hunks)
  • fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (9 hunks)
  • fla/ops/generalized_delta_rule/iplr/chunk.py (2 hunks)
  • fla/ops/generalized_delta_rule/iplr/wy_fast.py (4 hunks)
✅ Files skipped from review due to trivial changes (3)
  • fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py
  • fla/ops/generalized_delta_rule/dplr/chunk.py
  • fla/ops/generalized_delta_rule/iplr/chunk.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • fla/ops/gated_delta_rule/chunk.py
  • fla/ops/delta_rule/chunk.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (2)
fla/ops/generalized_delta_rule/iplr/wy_fast.py (5)
  • prepare_wy_repr_fwd_kernel_chunk32 (28-68)
  • prepare_wy_repr_fwd_kernel_chunk64 (82-148)
  • wu_fwd_kernel (162-213)
  • wu_fwd (259-295)
  • prepare_wy_repr_fwd (216-256)
fla/ops/delta_rule/wy_fast.py (1)
  • prepare_wy_repr_fwd (180-206)
⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (20)
fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py (1)

27-27: Function renaming enhances clarity and consistency across the codebase.

The functions have been renamed to follow a more consistent naming convention:

  • prepare_wy_repr_fwd_kernel_chunk32 instead of fwd_prepare_wy_repr_kernel_chunk32
  • prepare_wy_repr_fwd_kernel_chunk64 instead of fwd_prepare_wy_repr_kernel_chunk64
  • wu_fwd_kernel instead of fwd_wu_kernel
  • wu_fwd instead of fwd_wu
  • prepare_wy_repr_fwd instead of fwd_prepare_wy_repr

This naming convention (putting the operation type after the function name) is more consistent with Python conventions and makes the code more readable.

Also applies to: 72-72, 151-151, 207-207, 225-225, 245-245, 259-259, 271-271, 282-284

fla/ops/delta_rule/wy_fast.py (4)

30-30: Function renaming and signature updates for consistency.

The function renaming follows the same pattern as in other files, improving code organization:

  • recompute_w_u_fwd_kernel (instead of fwd_recompute_w_u_kernel)
  • prepare_wy_repr_bwd_kernel (instead of bwd_prepare_wy_repr_kernel)
  • prepare_wy_repr_fwd (instead of fwd_prepare_wy_repr)

The signature of prepare_wy_repr_fwd has been updated to include the new g_cumsum=None parameter while removing the chunk_size parameter, which is now hardcoded to 64. This simplifies the interface by reducing the number of parameters while still supporting the new functionality.

Also applies to: 91-91, 180-185, 186-193
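As background for what `prepare_wy_repr_fwd` produces, here is a hedged single-chunk sketch of a common delta-rule WY formulation: build the strictly lower-triangular beta-scaled `K K^T` matrix `A`, then solve `(I + A) x = beta*K` (and `beta*V`) to get `w` and `u`. The signs and scaling are illustrative assumptions, not a line-for-line copy of fla's kernels:

```python
import numpy as np

def prepare_wy_repr_ref(k, v, beta, chunk_size=64):
    """Toy single-chunk WY-style forward pass (dense NumPy stand-in)."""
    C = min(chunk_size, k.shape[0])
    kc, vc, bc = k[:C], v[:C], beta[:C, None]
    A = np.tril(bc * (kc @ kc.T), k=-1)      # strictly causal interactions
    T_inv = np.linalg.inv(np.eye(C) + A)     # role played by solve_tril
    w = T_inv @ (bc * kc)
    u = T_inv @ (bc * vc)
    return w, u
```

With orthonormal keys the triangular term vanishes and `w`, `u` reduce to the beta-scaled inputs, which makes a convenient sanity check.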


209-215: Simplified function signature with hardcoded chunk size.

The recompute_w_u_fwd function (renamed from fwd_recompute_w_u) now has a simplified signature with the chunk_size parameter removed. The chunk size is now hardcoded to 64 (line 217), which streamlines the code and removes one configuration option that may have been unnecessary.

Also applies to: 217-218


247-255: Simplified parameter handling in backward propagation.

The prepare_wy_repr_bwd function now derives BT directly from A.shape[-1] instead of relying on the removed chunk_size parameter. This is a more direct approach that ensures consistency between the forward and backward operations, as it uses the actual dimensions from the tensor created in the forward pass.

Also applies to: 257-258


291-296: Added alias definitions for backward compatibility.

The addition of these aliases ensures backward compatibility with code that might be using the old function names, which is a good practice during refactoring to prevent breaking existing code that depends on these functions.
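The aliasing pattern itself is a one-liner; a toy sketch (the doubling function is a placeholder, not the real signature):

```python
def prepare_wy_repr_fwd(x):
    # New canonical name: the operation suffix ("fwd") comes last,
    # matching the naming convention this PR standardizes on.
    return x * 2

# Backward-compatible alias: old import paths keep working while
# new call sites migrate to the *_fwd spelling.
fwd_prepare_wy_repr = prepare_wy_repr_fwd
```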

fla/ops/gated_delta_rule/wy_fast.py (9)

27-40: Major refactoring of the backward kernel with additional parameters.

The backward kernel prepare_wy_repr_bwd_kernel now takes additional parameters:

  • v for processing value tensors
  • dw, du, dk, dv, dbeta, dg for gradient computation
  • Repositioned g parameter

This refactoring allows the backward kernel to handle both weight and value gradients in a single pass, which can improve computational efficiency.

Also applies to: 60-62


63-85: Improved gradient computation with refined tensor operations.

The gradient computation has been enhanced with more refined tensor operations:

  • Line 63-64: Using block pointers to efficiently access tensor data
  • Lines 65-66: Loading and preprocessing tensor data
  • Lines 74-79: Computing gradients for kernel weights
  • Lines 81-84: Applying causal mask and matrix multiplications for gradient chain rule

This approach ensures numerical stability and computational efficiency during backpropagation.


86-102: Added value tensor gradient computation.

New code segments have been added to compute gradients for value tensors:

  • Lines 86-88: Setting up block pointers and initializing tensors
  • Lines 90-101: Computing gradients for value tensors with proper scaling

This is a critical addition for end-to-end backpropagation in models using attention mechanisms.


103-113: Enhanced gradient processing with exponential scaling.

The gradient processing has been enhanced with exponential scaling:

  • Lines 103-106: Applying causal mask and matrix operations
  • Lines 108-113: Using safe exponential scaling based on gate values

This ensures numerical stability when dealing with potentially large values in the exponential space, which is important for gated architectures.


127-132: Final gradient aggregation and storage.

The final gradient aggregation and storage steps have been implemented:

  • Line 127: Additional scaling of gradients
  • Line 128: Computing gate gradients
  • Lines 129-132: Creating block pointers and storing the computed gradients

This completes the backward pass implementation with proper gradient calculation for all relevant tensors.


147-147: Consistent function renaming across the codebase.

Functions have been renamed to follow the consistent pattern established in other files:

  • recompute_w_u_fwd_kernel (instead of fwd_recompute_w_u_kernel)
  • recompute_w_u_fwd (instead of fwd_recompute_w_u)

This improves code organization and readability.

Also applies to: 201-201


214-216: Hardcoded chunk size for simplified configuration.

The chunk size parameter has been hardcoded to 64 across multiple functions:

  • Line 214-215: Changed variable initialization in recompute_w_u_fwd
  • Line 252-253: Set fixed chunk size in prepare_wy_repr_bwd

This simplification removes an unnecessary configuration parameter while still allowing the code to adapt to hardware constraints through the check_shared_mem() function.

Also applies to: 252-253


263-285: Updated kernel call with comprehensive gradient parameters.

The kernel call to prepare_wy_repr_bwd_kernel has been updated to include all gradient tensors required for the backward pass, matching the kernel's new parameter list. This ensures all necessary gradients are computed in a single kernel call, which is more efficient than multiple separate kernel calls.


289-291: Alias definitions for backward compatibility.

The alias definitions ensure backward compatibility with code that uses the old function names, which is a good practice during refactoring.

fla/ops/generalized_delta_rule/iplr/wy_fast.py (6)

28-28: Consistent function renaming for better code organization.

The functions have been renamed to follow the consistent pattern established in other files:

  • prepare_wy_repr_fwd_kernel_chunk32 (instead of fwd_prepare_wy_repr_kernel_chunk32)
  • prepare_wy_repr_fwd_kernel_chunk64 (instead of fwd_prepare_wy_repr_kernel_chunk64)
  • wu_fwd_kernel (instead of fwd_wu_kernel)

This improves code organization and readability.

Also applies to: 82-82, 162-162


216-224: Updated function signature with enhanced tensor initialization.

The prepare_wy_repr_fwd function signature has been updated:

  • Added b: torch.Tensor parameter
  • Changed return type to include the A tensor
  • Initialized BC (block count) variable for better chunk handling
  • Created the A tensor with proper dimensions
  • Added selection logic for the kernel function based on chunk size

These changes improve the function's flexibility and make it more consistent with the rest of the codebase.

Also applies to: 229-234


248-256: Refactored to use the dedicated wu_fwd function.

The code has been refactored to use the dedicated wu_fwd function instead of inline calculations:

  • Lines 248-255: Call to wu_fwd with appropriate parameters
  • Line 256: Updated return statement to include the A tensor

This separation of concerns improves code organization and maintainability.


259-266: New wu_fwd function with streamlined parameter handling.

A new wu_fwd function has been introduced with a clear separation of concerns:

  • Lines 259-266: Function signature with appropriate parameters and return type
  • Lines 267-275: Parameter handling and tensor dimension calculations

The function uses optimized tiling configuration based on hardware capabilities, which improves performance adaptability.

Also applies to: 267-278


278-295: Optimized kernel call with appropriate parameter mapping.

The kernel call to wu_fwd_kernel has been optimized:

  • Updated parameter order to match the kernel's signature
  • Properly mapped tensors to their corresponding parameters
  • Used hardware-aware tiling configurations for performance

This ensures efficient use of GPU resources during computation.


298-300: Added alias definitions for backward compatibility.

The alias definitions ensure backward compatibility with code that uses the old function names, which is a good practice during refactoring.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🔭 Outside diff range comments (1)
fla/ops/gated_delta_rule/wy_fast.py (1)

240-240: 🛠️ Refactor suggestion

Introduction of bwd_prepare_wy_repr function.

Provides the new public API to compute backward WY representation. Ensure that any references to the old forward-based kernel are removed throughout the codebase.

🧹 Nitpick comments (5)
fla/ops/utils/solve_tril.py (1)

163-174: Repeated float32 casting across larger blocks.

For these merged blocks, confirm that casting all loaded data to float32 aligns with your precision needs. If half precision suffices, you can preserve performance by avoiding unneeded upcasting.

fla/ops/gated_delta_rule/chunk.py (1)

30-30: Hard-coded chunk size of 64 in chunk_local_cumsum.

Using a fixed chunk size simplifies usage but might reduce flexibility if future changes require varying chunk sizes. Consider making it configurable if future scenarios demand it.
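The semantics of a chunk-local cumulative sum can be sketched in NumPy (a 1-D stand-in; the real `chunk_local_cumsum` operates on `[B, T, H]` tensors, and the fixed chunk size of 64 is what lets downstream kernels assume one intra-chunk layout). The `reverse` flag mirrors the reversed cumsum used in the backward pass:

```python
import numpy as np

def chunk_local_cumsum_ref(g, chunk_size=64, reverse=False):
    """Cumsum each chunk of length `chunk_size` independently."""
    out = np.empty_like(g)
    for s in range(0, len(g), chunk_size):
        c = g[s:s + chunk_size]
        out[s:s + chunk_size] = c[::-1].cumsum()[::-1] if reverse else c.cumsum()
    return out
```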

fla/ops/common/chunk_scaled_dot_kkt.py (1)

31-31: New parameters g_cumsum and Ag.

Adding g_cumsum to the kernel and returning Ag expand functionality. Verify that memory usage is acceptable when storing this extra tensor.

Also applies to: 33-33

fla/ops/gated_delta_rule/wy_fast.py (2)

45-45: New constants for V and BV.

Defines dimension for values. This brings flexibility in computing with the backward pass, but watch out for memory alignment on large V.

Also applies to: 48-48


252-252: Fixed BT = 64 in the backward interface.

Hard-coded chunk size again. Consider allowing a parameter if future expansions or hardware constraints differ.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6bfd5e6 and d744a78.

📒 Files selected for processing (6)
  • fla/ops/common/chunk_scaled_dot_kkt.py (5 hunks)
  • fla/ops/delta_rule/chunk.py (7 hunks)
  • fla/ops/delta_rule/wy_fast.py (4 hunks)
  • fla/ops/gated_delta_rule/chunk.py (4 hunks)
  • fla/ops/gated_delta_rule/wy_fast.py (4 hunks)
  • fla/ops/utils/solve_tril.py (3 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (4)
fla/ops/gated_delta_rule/chunk.py (5)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
  • chunk_scaled_dot_kkt_fwd (77-123)
fla/ops/delta_rule/wy_fast.py (2)
  • bwd_prepare_wy_repr (247-288)
  • fwd_recompute_w_u (209-244)
fla/ops/gated_delta_rule/wy_fast.py (2)
  • bwd_prepare_wy_repr (240-286)
  • fwd_recompute_w_u (201-237)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (386-406)
fla/ops/utils/solve_tril.py (1)
  • solve_tril (222-276)
fla/ops/gated_delta_rule/wy_fast.py (4)
fla/ops/delta_rule/wy_fast.py (1)
  • bwd_prepare_wy_repr_kernel (91-177)
fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py (1)
  • bwd_prepare_wy_repr_kernel (30-114)
fla/ops/utils/op.py (1)
  • safe_exp (28-29)
fla/ops/common/utils.py (1)
  • prepare_chunk_indices (59-64)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
fla/ops/utils/op.py (1)
  • safe_exp (28-29)
fla/ops/delta_rule/wy_fast.py (2)
fla/ops/utils/solve_tril.py (1)
  • solve_tril (222-276)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
  • chunk_scaled_dot_kkt_fwd (77-123)
⏰ Context from checks skipped due to timeout of 90000ms (1)
  • GitHub Check: test
🔇 Additional comments (38)
fla/ops/utils/solve_tril.py (2)

51-51: Consider verifying performance impact of float32 casting.

Casting this loaded block to float32 may improve numerical stability but could increase memory usage. Confirm that performance remains acceptable given your hardware constraints.


107-109: Same float32 casting note as above.

Again, ensure that repeatedly casting these blocks to float32 is intended from both a memory footprint and compute performance perspective.

fla/ops/gated_delta_rule/chunk.py (9)

13-15: Imports consolidated for chunk-based gating.

No immediate issues. These new imports enable the forward/backward transformations using chunk-based logic.


28-28: Optional cu_seqlens parameter.

Marking cu_seqlens as optional reduces code complexity but ensure all call sites handle None gracefully.


32-38: Using chunk_scaled_dot_kkt_fwd with g_cumsum.

The call to compute both Aw and Au from g_cumsum is logical for the gating approach. Verify the correct dtype is returned as requested (torch.float32).


39-43: solve_tril usage for Aw.

Good approach to invert the lower triangular portion of Aw. Double-check that all code paths handle shape [B, T, H, 16/32/64] as expected.


44-48: solve_tril usage for Au.

Similar logic as for Aw. Ensures Au is also inverted if needed. Confirm that in-place modifications won’t conflict with parallel usage.
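What `solve_tril` computes per block can be sketched as inverting a unit lower-triangular matrix by forward substitution (assuming the `(I + strictly-lower A)` form suggested by the surrounding code; dense NumPy stand-in for the blockwise float32 Triton kernel):

```python
import numpy as np

def solve_tril_ref(A):
    """Invert I + tril(A, -1) column by column via forward substitution."""
    n = A.shape[0]
    L = np.eye(n) + np.tril(A, k=-1)   # unit lower-triangular block
    inv = np.zeros_like(L)
    for j in range(n):                 # solve L x = e_j
        x = np.zeros(n)
        x[j] = 1.0
        for i in range(j + 1, n):
            x[i] = -(L[i, :i] @ x[:i])
        inv[:, j] = x
    return inv
```

The unit diagonal means no divisions are needed, which is part of why doing the arithmetic in float32 (as the PR now enforces) is mainly about accumulation accuracy rather than avoiding division blow-ups.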


49-55: Recomputed w and u from Aw and Au.

Forward recomputation is consistent with the gating logic. No apparent issues; ensure shape alignment is tested.


129-129: Empty changed line.

No actionable content.


155-155: Addition of gradients to dg.

Combining partial gradients is standard. Confirm that no additional scaling or normalization is needed in this step.


157-157: Reversing chunked cumsum for the gradient.

This reversed cumsum ensures correct backprop for gating. Make sure test coverage verifies boundary conditions for variable-length inputs.

fla/ops/common/chunk_scaled_dot_kkt.py (7)

11-11: Import safe_exp for exponent damping.

This may help prevent overflow, but ensure large negative values are handled correctly.


15-16: New heuristics for variable-length and gating usage.

USE_G is defined based on g_cumsum. Logic is straightforward and helps conditionally compute the gating matrix.


68-75: Conditional gating logic.

When USE_G is true, the code computes b_Ag = b_A * safe_exp(b_g_diff). Make sure b_g_diff is bounded, or it could lead to NaN if differences are positive and large.
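One common way to make such an exponential "safe" is to clamp the exponent at 0 so the result never exceeds 1 and cannot overflow (another variant maps positive inputs to `-inf`, yielding 0; which one fla's `safe_exp` uses is not shown here, so treat this as an illustrative assumption):

```python
import numpy as np

def safe_exp(x):
    # Clamp the exponent at 0: exp(min(x, 0)) is bounded by 1,
    # so large positive gate differences cannot produce inf/NaN.
    return np.exp(np.minimum(x, 0.0))
```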


80-80: Optional g_cumsum typed annotation.

No issues; consistent with your approach.


109-109: Allocate Ag only if g_cumsum is not None.

Saves memory when gating is not used; reasonable approach.


113-115: Passing g_cumsum and Ag to the kernel in the launch call.

Aligned with the new logic to compute the gating portion.


123-123: Returning (A, Ag) as a tuple.

The introduction of Ag as a second matrix is consistent with the gating extension. Confirm that all call sites correctly handle the tuple return.

fla/ops/gated_delta_rule/wy_fast.py (7)

21-21: Narrowing warp configurations from [2, 4] to optimize.

Fewer warp sizes can improve tuning time but might limit performance on some GPUs.


24-24: Extended autotune key.

Now includes ['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN']. This ensures specific dimension parameters drive kernel configuration.


27-27: New bwd_prepare_wy_repr_kernel definition.

Replaces the forward kernel with a backward approach for WY representation. This is a major addition; ensure thorough testing of backward pass correctness.


29-39: New parameters for backward pass (v, g, dw, du, dk, dv, dbeta, dg).

These additions allow gradient accumulation and distribution across w, u, k, v, and gating components. Make sure each is sized properly at call time.


60-61: Initializing local gradient accumulators.

Storing partial sums in b_dbeta and b_dA. Ensure these are zeroed as intended before the partial dot accumulation.


71-75: Dot products for dW and dK.

These lines accumulate partial derivatives with controlled TF32 usage. Confirm this matches the forward pass precision.


129-132: Storing computed dg and dbeta.

Completes gradient updates for gating and betas. Verify that no stride issues occur with large batch sizes.

fla/ops/delta_rule/wy_fast.py (4)

12-12: LGTM: Updated import path for solve_tril.

The import has been updated to use the more modular approach from the utils package.


186-193: Added g_cumsum parameter and hardcoded chunk_size.

The chunk_scaled_dot_kkt_fwd call has been updated to include the new g_cumsum parameter (set to None) and explicitly set chunk_size=64 rather than using a parameter. This aligns with the optimization focus of the PR where chunk_size is standardized across the codebase.


217-217: Hardcoded block tile size to 64.

The variable BT is now directly set to 64 rather than being derived from the chunk_size parameter, which simplifies the implementation and makes the code more predictable.


257-257: Improved BT calculation based on input shape.

Now BT is derived from A.shape[-1] instead of relying on the chunk_size parameter, which makes the code more robust as it adapts to the actual dimensions of the input tensor.

fla/ops/delta_rule/chunk.py (9)

28-33: LGTM: Simplified function call by removing chunk_size parameter.

The call to fwd_prepare_wy_repr has been cleaned up by removing the chunk_size parameter, which is now handled internally in the called function.


34-42: Simplified chunk_gated_delta_rule_fwd_h call.

The chunk_size parameter has been removed from the function call, streamlining the API. This is consistent with the optimization-focused changes across the codebase.


44-52: LGTM: Removed chunk_size from chunk_fwd_o call.

Consistent with other changes, the chunk_size parameter has been removed from the function call, simplifying the interface.


68-74: Simplified fwd_recompute_w_u call by removing chunk_size parameter.

The function call is now more straightforward with the removal of the chunk_size parameter, which is now handled internally within the function.


84-91: LGTM: Simplified chunk_bwd_dv_local call.

The chunk_size parameter has been removed, which aligns with the optimization changes being made throughout the codebase.


92-103: Simplified chunk_gated_delta_rule_bwd_dhu call.

The removal of the chunk_size parameter from this function call maintains consistency with the other optimization changes in the PR.


104-116: LGTM: Simplified chunk_bwd_dqkwg call.

The chunk_size parameter has been removed from this call as well, maintaining a consistent approach to API simplification across the codebase.


117-125: Simplified bwd_prepare_wy_repr call.

The chunk_size parameter has been removed, aligning with the pattern of simplifying function interfaces throughout this PR.


184-195: LGTM: Simplified backwards pass function call.

The chunk_size parameter has been removed from the backward function call, completing the consistent simplification of the API across the codebase.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
fla/ops/common/chunk_delta_h.py (1)

363-369: Consolidated variable usage for K > 192 case with a minor inconsistency.

While the variable reuse pattern is maintained, there is an inconsistency with the p_d variable on line 366: unlike the previous sections, where both p_q and p_d are newly defined, here a new b_d is loaded without first creating a new p_d pointer.

For complete consistency with previous patterns, consider adding:

 p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
+p_d = tl.make_block_ptr(d, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
 b_q = tl.load(p_q, boundary_check=(0, 1))
 b_q = (b_q * scale).to(b_q.dtype)
 b_d = tl.load(p_d, boundary_check=(0, 1))
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d478ca0 and 817398a.

📒 Files selected for processing (1)
  • fla/ops/common/chunk_delta_h.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: test
  • GitHub Check: test
🔇 Additional comments (5)
fla/ops/common/chunk_delta_h.py (5)

318-321: Code optimization: Consolidation of pointer and tensor variables.

The code now reuses p_k and b_k variables instead of having separate variables for each condition, reducing redundancy and improving maintainability.


324-326: Consistent variable reuse pattern applied for K > 128 case.

Good continuation of the variable consolidation pattern, maintaining consistency with the previous optimization.


329-331: Consistent variable reuse pattern applied for K > 192 case.

Same optimization approach extended to the final case, completing the variable consolidation pattern throughout all conditional branches.


345-352: Comprehensive refactoring of the backward kernel logic for K > 64.

The refactoring here extends beyond simple variable reuse to include a more streamlined approach for handling q and d tensors in the K > 64 case. The code is now more consistent and better organized.


354-361: Consistent refactoring pattern applied for K > 128 case.

This change maintains the same optimization pattern as the K > 64 case, creating a uniform approach throughout the function.

@yzhangcs yzhangcs merged commit 465bc33 into main Apr 13, 2025
3 of 6 checks passed
@yzhangcs yzhangcs deleted the fast-gdn branch April 13, 2025 07:33
@alxndrTL

Hello! Is this optimization available by default for DeltaNet/GDN now? And do you have a rough approximation of the speedup?


3 participants