
[kda kernel optimization] implement token-parallel intra-chunk attention #653

Merged
sustcsonglin merged 5 commits into main from fix-safe-exp
Nov 20, 2025

Conversation


@sustcsonglin commented Nov 20, 2025

KDA Intra Chunk Attention: Token-Parallel Implementation Benchmark

1. Overview

This report benchmarks the performance of the newly implemented Token-Parallel kernel for KDA intra-chunk attention against the Original (Sub-Chunk) implementation.

The goal was to improve efficiency for large-scale training scenarios by:

  1. Avoiding redundant computation for causal masking (loop bounds respect causality).
  2. Marginalizing the entire K dimension at once (no tiling loop).

2. Methodology

2.1 Function Benchmarked

Target Function: chunk_kda_fwd_intra in fla/ops/kda/chunk_intra.py

Comparison:

  • Original: Uses chunk_kda_fwd_kernel_intra_sub_intra (Sub-Chunk Parallelism).
    • Triggered by: use_token_parallel=False
  • Token-Parallel: Uses chunk_kda_fwd_intra_token_parallel (Token Parallelism).
    • Triggered by: use_token_parallel=True (New Default)

2.2 Environment

  • Device: NVIDIA H200 GPU (CUDA)
  • Precision: torch.float16
  • Metric: Average execution time (ms) over 20 runs (after 10 warmups).
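The timing procedure above can be sketched as a minimal harness (the 10-warmup/20-run counts match the report; `fn` stands in for the actual kernel launch, and on GPU each call would be followed by a `torch.cuda.synchronize()` so the timer measures kernel completion rather than launch overhead):

```python
import time

def benchmark(fn, *args, warmup=10, runs=20):
    """Average wall-clock time (ms) of fn over `runs` timed calls after `warmup` calls."""
    for _ in range(warmup):
        fn(*args)           # warmup: triggers compilation/autotuning, not timed
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(*args)
        times.append((time.perf_counter() - start) * 1e3)  # seconds -> ms
    return sum(times) / len(times)
```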

2.3 Test Configurations

We selected configurations representative of large-scale LLM training:

| Description | Batch Size (B) | Seq Len (T) | Heads (H) | Head Dim (K) |
| --- | --- | --- | --- | --- |
| Standard Large | 16 | 4096 | 16 | 64 |
| Heavy Head Dim | 8 | 4096 | 32 | 128 |
| Long Sequence | 8 | 8192 | 16 | 64 |
| Very Heavy | 4 | 8192 | 32 | 128 |

3. Performance Results

| Configuration | Original (ms) | Token-Parallel (ms) | Speedup | Status |
| --- | --- | --- | --- | --- |
| Standard Large (B=16, T=4096, H=16, K=64) | 1.966 | 1.760 | 1.12x | ✅ Faster |
| Heavy Head Dim (B=8, T=4096, H=32, K=128) | 2.829 | 2.369 | 1.19x | ✅ Faster |
| Long Sequence (B=8, T=8192, H=16, K=64) | 1.998 | 1.772 | 1.13x | ✅ Faster |
| Very Heavy (B=4, T=8192, H=32, K=128) | 2.859 | 2.356 | 1.21x | ✅ Faster |

3.1 Key Observations

  1. Consistent Speedup: The Token-Parallel implementation is consistently 12% to 21% faster across all tested large-scale configurations.
  2. Scalability: The speedup grows with model complexity, peaking at 1.21x when the head dimension K reaches 128.
  3. Efficiency: The removal of transpose operations and improved memory coalescing contribute significantly to the performance gains.

4. Implementation Details

4.1 Original Implementation (Sub-Chunk)

  • Granularity: Sub-Chunk Parallelism (1 thread block per sub-chunk of 16 tokens).
  • Grid: (NT, NC, B*H) where NT is number of chunks, NC is sub-chunks per chunk.
  • K-Dimension: Uses tiling loop over K dimension.
  • Heads: Processes 1 head per thread block.
  • Cons:
    • Overhead from K-dimension tiling loops.
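The two launch grids can be compared with a quick block count (a pure-Python sketch; the helper names are illustrative, and the chunk size BT=64, sub-chunk size BC=16, and head block BH=4 are assumptions based on the 16-token sub-chunks described above and an autotuned BH):

```python
def sub_chunk_grid(B, T, H, BT=64, BC=16):
    """Thread blocks for the sub-chunk scheme: grid = (NT, NC, B * H)."""
    NT = (T + BT - 1) // BT   # chunks per sequence
    NC = (BT + BC - 1) // BC  # sub-chunks per chunk
    return NT * NC * B * H

def token_parallel_grid(B, T, H, BH=4):
    """Thread blocks for the token-parallel scheme: grid = (B * T, H // BH)."""
    return (B * T) * (H // BH)

# "Standard Large" config from the benchmark table above
print(sub_chunk_grid(16, 4096, 16))       # sub-chunk scheme
print(token_parallel_grid(16, 4096, 16))  # token-parallel scheme
```

The token-parallel grid launches more (smaller) blocks, but each block fuses BH heads, so the raw block counts are not directly proportional to work per block.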

4.2 Token-Parallel Implementation

  • Granularity: Token Parallelism (1 thread block per token).
  • Grid: (Total Tokens, H/BH) where BH is autotuned head block size.
  • K-Dimension: Marginalizes entire K dimension at once (no tiling loop).
  • Heads: Processes BH heads per thread block (Head fusion).
  • Pros:
    • Reduced FLOPs: Loops only up to i_t, skipping upper-triangular computations completely (unlike tiled approach which masks after compute).
    • Reduced Register Pressure: Eliminating K-dimension tiling and simpler control flow allows for potentially higher occupancy.
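The "Reduced FLOPs" point can be made concrete with a small count of attention-score computations per chunk (a pure-Python sketch; BT=64 and BC=16 are illustrative values, not read from the kernel):

```python
def tiled_scores(BT=64, BC=16):
    """Sub-chunk style: compute full BC x BC tiles on/below the diagonal,
    then mask the upper triangle afterwards -- masked entries still cost work."""
    n_tiles = BT // BC
    work = 0
    for i in range(n_tiles):
        for j in range(i + 1):  # tiles strictly above the diagonal are skipped
            work += BC * BC
    return work

def token_parallel_scores(BT=64):
    """Token-parallel style: token t only loops over keys 0..t, so
    upper-triangular entries are never computed at all."""
    return sum(t + 1 for t in range(BT))

print(tiled_scores())           # includes masked-out work
print(token_parallel_scores())  # strictly causal work
```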

5. Conclusion

The Token-Parallel implementation provides a significant performance boost for KDA intra-chunk attention, especially in computation-heavy scenarios relevant to large-scale training. It is now set as the default implementation.

  • Recommended Default: use_token_parallel=True
  • Fallback Available: use_token_parallel=False (Original implementation kept for backward compatibility/verification).

- Replace sub_intra kernel with token-parallel implementation for better performance on large sequences.
- Support both fixed-length and variable-length sequences in token-parallel kernel.
- Fix safe_exp usage and add causal masking in log_linear_attn chunk ops.

@coderabbitai bot commented Nov 20, 2025


Walkthrough

Adds a token-parallel execution path to the KDA intra-chunk computation by introducing a conditional flag in chunk_kda_fwd_intra that routes to a new Triton kernel implementation. The new module provides token-parallel kernel logic with support for variable-length sequences.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Conditional dispatch in chunk_intra<br>`fla/ops/kda/chunk_intra.py` | Added a `use_token_parallel: bool = True` parameter to the `chunk_kda_fwd_intra` function signature; the function now conditionally calls `chunk_kda_fwd_intra_token_parallel` when the flag is True, otherwise uses the original path. |
| Token-parallel kernel implementation<br>`fla/ops/kda/chunk_intra_token_parallel.py` | New module containing the Triton JIT kernel `chunk_kda_fwd_kernel_intra_token_parallel` for token-parallel KDA computation with variable-length sequence support via binary search over `cu_seqlens`; includes the public wrapper `chunk_kda_fwd_intra_token_parallel` with autotune configuration. |
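The variable-length lookup mentioned above—mapping a global token index to its owning sequence via binary search over `cu_seqlens`—can be sketched in plain Python (the kernel does the equivalent with Triton scalar ops; the function name here is illustrative):

```python
def find_seq_idx(cu_seqlens, token_idx):
    """Return i such that cu_seqlens[i] <= token_idx < cu_seqlens[i + 1]."""
    lo, hi = 0, len(cu_seqlens) - 1
    while lo < hi - 1:
        mid = (lo + hi) // 2
        if cu_seqlens[mid] <= token_idx:
            lo = mid   # token lies at or after this boundary
        else:
            hi = mid   # token lies before this boundary
    return lo

# Three packed sequences of lengths 5, 3, and 8
cu_seqlens = [0, 5, 8, 16]
print(find_seq_idx(cu_seqlens, 0))   # -> 0 (first sequence)
print(find_seq_idx(cu_seqlens, 6))   # -> 1 (second sequence)
print(find_seq_idx(cu_seqlens, 15))  # -> 2 (third sequence)
```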

Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant chunk_kda_fwd_intra
    participant chunk_kda_fwd_intra_token_parallel
    participant Triton Kernel

    Caller->>chunk_kda_fwd_intra: call with use_token_parallel flag
    alt use_token_parallel == True
        chunk_kda_fwd_intra->>chunk_kda_fwd_intra_token_parallel: dispatch to wrapper
        chunk_kda_fwd_intra_token_parallel->>Triton Kernel: configure grid and launch
        Triton Kernel->>Triton Kernel: per-token, per-head-group processing
        Triton Kernel->>Triton Kernel: binary search cu_seqlens (if varlen)
        Triton Kernel->>Triton Kernel: compute Aqk, Akk with gating & masking
        Triton Kernel-->>chunk_kda_fwd_intra_token_parallel: return results
    else use_token_parallel == False
        chunk_kda_fwd_intra->>chunk_kda_fwd_intra: use original sub-intra path
    end
    chunk_kda_fwd_intra-->>Caller: return (output, tril_decomposition)
```

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Key areas requiring attention:
    • Triton kernel logic in chunk_intra_token_parallel.py: binary search implementation over cu_seqlens, gating mechanics (exp vs exp2), masking logic, and boundary condition handling for variable-length sequences
    • Index bounds safety checks within the kernel and correctness of the per-token accumulation pattern
    • Autotune configuration and heuristics for optimal block/warp selection
    • Integration point in chunk_kda_fwd_intra: verify conditional dispatch logic and parameter forwarding to the new kernel

Possibly Related PRs

  • #621: Introduces the token-parallel kernel that is now being integrated as an optional execution path via the new use_token_parallel flag in chunk_kda_fwd_intra.


Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, below the required 80.00% threshold. | Run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title Check | ✅ Passed | The pull request title accurately describes the main changes: implementing a token-parallel execution path for intra-chunk KDA attention computation. |

@gemini-code-assist
Contributor

Summary of Changes

Hello @sustcsonglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization for KDA (Kimi Delta Attention) models by implementing a token-parallel strategy for intra-chunk attention computation. This new approach aims to enhance computational efficiency, especially for variable-length sequences, by allowing each token to be processed by its own thread block. Concurrently, the PR cleans up deprecated code related to fused_chunk_gla and refines several existing Triton kernels to improve boundary condition handling and numerical stability, ensuring more robust and accurate attention calculations.

Highlights

  • Token-Parallel Intra-Chunk Attention: Implemented a new token-parallel approach for computing intra-chunk attention in KDA models, designed to improve performance and better handle variable sequence lengths.
  • Deprecation of fused_chunk_gla: The fused_chunk_gla function and its associated kernels have been deprecated and removed, with users directed to use chunk_gla instead.
  • Kernel Refinements and Boundary Handling: Existing Triton kernels for KDA and log-linear attention have been refined to include explicit masking for sequence boundaries and correct application of exponential functions, enhancing numerical stability and correctness.
  • Removal of safe_exp Utility: The safe_exp utility function has been removed, as its functionality is now integrated directly into the kernels using tl.where and tl.exp for more precise control.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant optimization for KDA intra-chunk attention by implementing a token-parallel kernel. This new implementation is added as a default execution path in chunk_kda_fwd_intra, with an option to fall back to the original implementation. The PR also deprecates fused_chunk_gla, guiding users towards chunk_gla. Additionally, there are several correctness fixes and refactorings across different kernels, such as adding explicit masking for padded tokens and replacing the implicit causal masking of safe_exp with explicit tl.where conditions, which improves code clarity and robustness.

Overall, the changes are well-implemented and contribute to better performance and maintainability. I have one minor suggestion to improve the documentation for the new token-parallel kernel.

Comment thread on fla/ops/kda/chunk_intra_token_parallel.py (outdated)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

@coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4dc9687 and c29719b.

📒 Files selected for processing (2)
  • fla/ops/kda/chunk_intra.py (4 hunks)
  • fla/ops/kda/chunk_intra_token_parallel.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
fla/ops/kda/chunk_intra.py (1)
fla/ops/kda/chunk_intra_token_parallel.py (2)
  • chunk_kda_fwd_intra_token_parallel (158-219)
  • grid (200-202)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops

Comment on lines +192 to +198:

```python
if cu_seqlens is not None:
    total_tokens = q.shape[1]
    # Use num_sequences as B for binary search
    B_kernel = len(cu_seqlens) - 1
else:
    total_tokens = B * T
    B_kernel = B
```

⚠️ Potential issue | 🔴 Critical

Fix varlen token count for token-parallel grid.

Line 193

When cu_seqlens is provided, total_tokens must come from the cumulative lengths (cu_seqlens[-1]). Using q.shape[1] only covers the padded T dimension, so in a batch with several sequences we launch work for at most T_max tokens and leave the remainder of the batch untouched (their Aqk/Akk rows stay zero). Please derive the grid size from the cumulative lengths instead.

```diff
-        total_tokens = q.shape[1]
+        total_tokens = int(cu_seqlens[-1].item())
```
📝 Committable suggestion


Suggested change:

```python
if cu_seqlens is not None:
    total_tokens = int(cu_seqlens[-1].item())
    # Use num_sequences as B for binary search
    B_kernel = len(cu_seqlens) - 1
else:
    total_tokens = B * T
    B_kernel = B
```
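The point of the suggested fix can be checked with plain numbers (illustrative values for a packed varlen batch; `B_kernel` and `total_tokens` follow the names used in the snippet above):

```python
# Two packed sequences of lengths 100 and 150 in one varlen batch.
cu_seqlens = [0, 100, 250]

# The kernel grid must cover every real token, so the count comes from the
# last cumulative boundary; any smaller count would leave trailing
# sequences' Aqk/Akk rows unwritten (they would stay zero).
B_kernel = len(cu_seqlens) - 1  # number of sequences
total_tokens = cu_seqlens[-1]   # true token count across the batch

print(B_kernel, total_tokens)  # -> 2 250
```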

@sustcsonglin merged commit b77fa00 into main Nov 20, 2025
6 checks passed
@sustcsonglin deleted the fix-safe-exp branch November 20, 2025 07:59