
[DeltaNet] WY repr speedup #279

Merged
sustcsonglin merged 4 commits into main from deltanet on Apr 2, 2025

Conversation

@yzhangcs (Member) commented Apr 2, 2025

Tested on an H100 machine:

python -m benchmarks.benchmark_training_throughput \
  --name delta_net \
  --batch_size 1 \
  --context_len 4096 \
  --seq_len 32768 \
  --varlen \
  --steps 1024 
            TGS
vanilla     47.9
save first  52.5

Summary by CodeRabbit

  • New Features

    • Introduced an additional input parameter to enhance data processing capabilities.
  • Refactor

    • Streamlined the internal logic for data mapping and conditional processing, improving reliability and overall performance.

@coderabbitai coderabbitai bot (Contributor) commented Apr 2, 2025

Walkthrough

The changes modify the fwd_prepare_wy_repr_kernel_chunk64 function in fla/ops/delta_rule/wy_fast.py by adding a new parameter At. Within the function, beta handling is split into two pointers (p_beta1 and p_beta2), and new pointers (p_A1 and p_A2) are introduced for processing data from At. The logic now conditionally zeroes values based on o_c comparisons, and the call from fwd_prepare_wy_repr is updated to pass the At tensor when BT equals 64.
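The block split described above can be sketched in NumPy. This is a hypothetical illustration of the idea, not the actual Triton kernel: the names A1/A2/A3 mirror the walkthrough (A3 standing in for the cross block written to At), and the negative strictly-lower-triangular masking mirrors the kernel's tl.where on o_c.

```python
import numpy as np

def wy_repr_blocks(k, beta, BC=32):
    """Sketch: build the strictly lower-triangular A blocks for one
    chunk of size 2*BC, split into two half-chunks of size BC.
    NumPy stand-in for the chunk64 kernel, with hypothetical names."""
    kb = k * beta[:, None]                 # beta-scaled keys
    A = -np.tril(kb @ k.T, -1)             # negative strict lower triangle
    A1 = A[:BC, :BC]                       # top-left diagonal block
    A2 = A[BC:, BC:]                       # bottom-right diagonal block
    A3 = A[BC:, :BC]                       # cross block (stored into At)
    return A1, A2, A3
```

Because every entry of the cross block lies strictly below the diagonal of the full chunk, A3 needs no triangular mask, which is what lets the kernel store it separately and process the two diagonal blocks independently.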

Changes

File: fla/ops/delta_rule/wy_fast.py
- Updated the fwd_prepare_wy_repr_kernel_chunk64 signature to include the new parameter At.
- Modified beta handling: split into two pointers (p_beta1 and p_beta2) with corresponding loads (b_beta1 and b_beta2).
- Added pointers p_A1 and p_A2 to handle At data with conditional zeroing using o_c.
- Updated the call in fwd_prepare_wy_repr to pass At when BT is 64.

Sequence Diagram(s)

sequenceDiagram
    participant Caller as fwd_prepare_wy_repr
    participant Kernel as fwd_prepare_wy_repr_kernel_chunk64
    Caller->>Kernel: Call with (k, beta, A, At, offsets, indices, T, H, K, BT, BK, BC, HEAD_FIRST, USE_OFFSETS)
    Kernel->>Kernel: Setup beta pointers (p_beta1, p_beta2)
    Kernel->>Kernel: Setup A pointers from A and At (p_A1, p_A2)
    Kernel->>Kernel: Apply conditional logic using o_c (zeroing when needed)
    Kernel->>Kernel: Compute and store results into p_A1 and p_A2
    Kernel-->>Caller: Return computed representation

Poem

(\_/)
I'm a code rabbit, hopping through the night,
Adding new paths for A's that shine so bright.
Two beta trails now guide my leap with glee,
Conditional hops over o_c set free.
In every byte and line, I celebrate with delight!


@coderabbitai coderabbitai bot (Contributor) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
fla/ops/delta_rule/wy_fast.py (1)

413-444: Split logic for BT == 64 vs. BT != 64
Having separate paths for chunk64 (using At) and chunk32 helps performance tuning but can lead to code duplication. Consider factoring out the common components and dispatching to specialized code only where the two paths genuinely diverge.
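The suggested factoring could look roughly like the following sketch. All names here (prepare_wy_repr, wy_repr_chunk32, wy_repr_chunk64) are placeholders, not the library's actual API; the stubs stand in for the two kernel launches.

```python
def wy_repr_chunk32(k, beta, A):
    return "chunk32"          # placeholder for the 32-wide kernel launch

def wy_repr_chunk64(k, beta, A, At):
    return "chunk64"          # placeholder for the 64-wide kernel launch

def prepare_wy_repr(k, beta, A, At=None, BT=64):
    """Hypothetical dispatcher mirroring the BT == 64 branch: only the
    chunk64 path needs the extra At buffer for cross-block results."""
    if BT == 64:
        return wy_repr_chunk64(k, beta, A, At)
    return wy_repr_chunk32(k, beta, A)
```

Keeping the dispatch in one thin wrapper means callers never need to know which chunk width (or which extra buffers) the specialized kernels require.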

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9c5382b and f60b673.

📒 Files selected for processing (1)
  • fla/ops/delta_rule/wy_fast.py (4 hunks)
🔇 Additional comments (13)
fla/ops/delta_rule/wy_fast.py (13)

99-99: Good addition of At parameter
This new parameter is consistently applied throughout the kernel, enhancing flexibility for storing partial results. No immediate concerns.


119-119: Index creation looks fine
Using o_c = tl.arange(0, BC) is a straightforward way to manage column indices. No issues noted.


122-125: Verify block pointer offsets in HEAD_FIRST branch
Splitting beta into p_beta1, p_beta2 and similarly p_A1, p_A2 is logical. However, confirm that (i_t * BT + BC) correctly offsets the second portion of data when BC is smaller or bigger than half of BT.


127-131: Check pointer arithmetic in else branch
The approach mirrors the HEAD_FIRST logic. Ensure that the offsets (bos*H + i_h) * BC and (i_t * BT + BC) are correct for both partial blocks. Also verify that b_beta1 loads as intended under all boundary conditions.


134-134: Initialization of block matrices
Allocating b_A1 to zeros is a standard approach to avoid stale data in GPU kernels. No further concerns.


139-142: Pointers to k segments
Double-check the boundary checks for these pointers in both HEAD_FIRST and non-HEAD_FIRST paths. Confirm that each sub-block read matches the intended slice of k for chunked processing.


144-148: Computing partial correlation blocks
Combining b_k1 with b_beta1 and updating b_A1 is consistent. The operation with tl.trans(b_k1) and allow_tf32=False should provide correct building-block multiplications.
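The K-blocked accumulation being reviewed can be checked against a single full matmul in NumPy. This is a sketch with hypothetical names, not the Triton code; it shows why accumulating per-BK slice is equivalent.

```python
import numpy as np

def corr_block(k1, beta1, BK=4):
    """Sketch of the K-blocked accumulation b_A1 += b_kb1 @ b_k1^T:
    summing the outer products of BK-wide key slices reproduces the
    full beta-scaled correlation matrix."""
    BC, K = k1.shape
    b_A1 = np.zeros((BC, BC))
    for i_k in range(0, K, BK):
        b_k1 = k1[:, i_k:i_k + BK]
        b_kb1 = b_k1 * beta1[:, None]   # scale rows by beta, as in the kernel
        b_A1 += b_kb1 @ b_k1.T
    return b_A1
```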


150-150: Cross-term accumulation
b_A3 += tl.dot(b_kb2, tl.trans(b_k1)) introduces a mixed product. Verify this cross-term is intentional and that dimensional alignment is correct.


152-153: Negative strict-lower-triangular extraction
Using -tl.where(o_c[:, None] > o_c[None, :], ...) flips the lower-triangular part. Confirm that this negative sign is consistent with your mathematical derivation.
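The masking pattern in question has a direct NumPy equivalent, sketched below with a hypothetical name: the o_c comparison keeps only entries strictly below the diagonal, and the leading minus negates them.

```python
import numpy as np

def neg_strict_lower(b_A):
    """Sketch of -tl.where(o_c[:, None] > o_c[None, :], b_A, 0):
    zero everything on or above the diagonal, negate the rest."""
    BC = b_A.shape[0]
    o_c = np.arange(BC)
    return -np.where(o_c[:, None] > o_c[None, :], b_A, 0.0)
```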


154-156: Storing partial results and synchronization
Storing partial matrices to At ahead of the tl.debug_barrier() ensures a clean handoff between threads. This pattern looks correct.


159-170: Iterative partial pivot updates
The loop adjusts entries for rows ≥ 1. Please confirm that the masking with (i_t * BT + i < T) and (i_t * BT + BC + i < T) avoids out-of-bounds when BC + i surpasses T.
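The row-by-row update being reviewed implements forward substitution for inverting a unit lower-triangular matrix. A minimal NumPy sketch (hypothetical name, boundary masking omitted): given strictly lower S, it produces N with (I - S)(I + N) = I. Row i only reads rows j < i, which the loop has already finalized, and S[i, i] = 0 prevents any self-reference.

```python
import numpy as np

def inv_unit_lower(S):
    """Forward-substitution sketch of the kernel's iterative update:
    for each row i, A[i] += A[i] @ A reads only already-updated rows
    j < i (S is strictly lower), yielding N = S + S @ N in place."""
    A = S.copy()
    n = A.shape[0]
    for i in range(1, n):
        A[i] = A[i] + A[i] @ A   # rows j < i already hold final values
    return A
```

The kernel's extra masks (i_t * BT + i < T and i_t * BT + BC + i < T) only guard the ragged final chunk; the recurrence itself is the same.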


175-176: Including identity on the diagonal
Ensuring b_A1 and b_A2 have ones on the diagonal makes sense for further triangular/inverse operations.


189-189: Final store of b_A1
Assigning b_A1 back into global memory wraps up the partial computation. Looks consistent with the rest of the kernel’s flow.

@yzhangcs yzhangcs changed the title from [DeltaNet] WY repr Speedup to [DeltaNet] WY repr speedup on Apr 2, 2025
@sustcsonglin sustcsonglin merged commit 2473f8f into main Apr 2, 2025
3 of 6 checks passed
@sustcsonglin sustcsonglin deleted the deltanet branch April 2, 2025 16:48

2 participants